mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/nn/wrap/cell_wrapper.py CHANGED

@@ -7,6 +7,7 @@
 #
 # http://www.apache.org/licenses/LICENSE-2.0
 #
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -85,8 +86,8 @@ class WithLossCell(Cell):
         loss_fn (Cell): The loss function used to compute loss.

     Inputs:
-        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
-        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
+        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. The dtype of `data` must be float16 or float32.
+        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`. The dtype of `label` must be float16 or float32.

     Outputs:
         Tensor, a tensor means the loss value, the shape of which is usually :math:`()`.
@@ -328,9 +329,11 @@ class TrainOneStepCell(Cell):
     Args:
         network (Cell): The training network. The network only supports single output.
         optimizer (Union[Cell]): Optimizer for updating the network parameters.
-        sens (numbers.Number): The scaling number to be filled as the input of backpropagation.
+        sens (numbers.Number, optional): The scaling number to be filled as the input of backpropagation.
+            Default value is
             ``None`` , which is ``1.0`` .
-        return_grad (bool): Whether to return gradient. If ``True``,
+        return_grad (bool, optional): Whether to return gradient. If ``True``,
+            it will return the gradient in the form of a dict
             while returning loss. The key of the dict is the parameter name corresponding to the gradient, and value
             is the gradient value. Default value is ``False`` .
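The hunk above documents the `sens` and `return_grad` keywords. As a quick illustration, here is a minimal sketch of `return_grad=True` built only from that docstring; the one-layer network, loss, and random data are placeholders, and the exact return layout should be confirmed against the 2.6.0 API reference.

import numpy as np
import mindspore as ms
import mindspore.nn as nn

net = nn.Dense(4, 2)                                    # illustrative one-layer network
loss_fn = nn.MSELoss()
net_with_loss = nn.WithLossCell(net, loss_fn)           # data/label dtype must be float16 or float32 per the new doc
opt = nn.SGD(net.trainable_params(), learning_rate=0.01)
train_step = nn.TrainOneStepCell(net_with_loss, opt, return_grad=True)
data = ms.Tensor(np.random.rand(8, 4), ms.float32)
label = ms.Tensor(np.random.rand(8, 2), ms.float32)
loss, grads = train_step(data, label)                   # per the docstring, grads maps parameter names to gradients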
@@ -529,6 +532,20 @@ class _VirtualDatasetCell(Cell):
         return self._backbone(*output)


+def _pipeline_clear_grad(accu_grad, grad):
+    accu_grad = F.depend(accu_grad, grad)
+    zeros = F.zeros_like(accu_grad)
+    return F.assign(accu_grad, zeros)
+
+def grad_scale(scale, grad):
+    """grad_scale"""
+    new_grad = scale * grad
+    grad = ops.depend(grad, new_grad)
+    zeros = F.zeros_like(grad)
+    new_grad = ops.depend(new_grad, F.assign(grad, zeros))
+    return new_grad
+
+
 @_primexpr
 def _check_shape_value_on_axis_divided_by_target_value(input_shape, micro_size):
     if F.isconstant(input_shape[0]) is False:
@@ -571,122 +588,13 @@ class _MicroBatch(Cell):
         return micro_inputs


-class MicroBatchInterleaved(Cell):
-    """
-    This function splits the input at the 0th into interleave_num pieces and then performs
-    the computation of the wrapped cell. Application scenario: When there is model parallelism in semi-automatic mode
-    and network, if the first slice data is calculating forward, the second slice data will execute the
-    communication operators at the same time, to achieve the performance acceleration of communication and computing
-    concurrency.
-
-    Note:
-        The output of the input network must be a single tensor.
-
-    Args:
-        network (Cell): The target network to wrap.
-        interleave_num (int, optional): split num of batch size. Default: ``2`` .
-
-    Inputs:
-        tuple[Tensor]. It's the same with the input of the `network` .
-
-    Outputs:
-        Tensor. The output of the input `network` .
-
-    Supported Platforms:
-        ``Ascend`` ``GPU``
-
-    Examples:
-        >>> import mindspore.nn as nn
-        >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
-        >>> net = LeNet5()
-        >>> net = nn.MicroBatchInterleaved(net, 2)
-    """
-    def __init__(self, network, interleave_num=2):
-        super(MicroBatchInterleaved, self).__init__(auto_prefix=False)
-        if not isinstance(interleave_num, int):
-            raise TypeError("For 'MicroBatchInterleaved', the argument 'interleave_num' must be integer, "
-                            "but got the type : {}.".format(type(interleave_num)))
-        if interleave_num <= 0:
-            raise ValueError("For 'MicroBatchInterleaved', the argument 'interleave_num' must be large than 0, "
-                             "but got {}.".format(interleave_num))
-        self.network = network
-        self.interleave_num = interleave_num
-        self.interleave_inputs = nn.CellList()
-        self.add = P.Add().add_prim_attr("micro_interleaved_add_flag", True)
-        for _ in range(interleave_num):
-            interleave_data = _MicroBatch(interleave_num)
-            interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True)
-            interleave_data.strided_slice.add_prim_attr("interleave_num", interleave_num)
-            self.interleave_inputs.append(interleave_data)
-        self._get_attr_from_cell(network)
-
-    def construct(self, *inputs):
-        output = 0.0
-        for i in range(self.interleave_num):
-            interleave_input = self.interleave_inputs[i](i, *inputs)
-            output = self.add(output, self.network(*interleave_input))
-        return output
-
-
-class PipelineCell(Cell):
-    """
-    Slice MiniBatch into finer-grained MicroBatch for use in pipeline-parallel training.
-
-    Note:
-        micro_size must be greater or equal to pipeline stages.
-
-    Args:
-        network (Cell): The target network to wrap.
-        micro_size (int): MicroBatch size.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU``
-
-    Examples:
-        >>> import mindspore.nn as nn
-        >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
-        >>> net = LeNet5()
-        >>> net = nn.PipelineCell(net, 4)
-    """
-    def __init__(self, network, micro_size):
-        super(PipelineCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.micro_inputs = nn.CellList()
-        self.micro_size = micro_size
-        self.add_list = []
-        if not isinstance(network, Cell):
-            raise TypeError("For 'PipelineCell', the argument 'network' must cell type, "
-                            "but got the type : {}.".format(type(network)))
-        if not isinstance(micro_size, int):
-            raise TypeError("For 'PipelineCell', the argument 'micro_size' must be integer, "
-                            "but got the type : {}.".format(type(micro_size)))
-        if micro_size <= 0:
-            raise ValueError("For 'PipelineCell', the argument 'micro_size' must be large than 0, "
-                             "but got {}.".format(micro_size))
-        for i in range(micro_size):
-            micro_input = _MicroBatch(micro_size)
-            self.micro_inputs.append(micro_input)
-            self.add = P.Add().add_prim_attr("pipeline_end", i)
-            self.add_list.append(self.add)
-        self._get_attr_from_cell(network)
-
-    def construct(self, *inputs):
-        ret = None
-        for i in range(self.micro_size):
-            micro_input = self.micro_inputs[i](i, *inputs)
-            output = self.network(*micro_input)
-            if ret is not None:
-                ret = self.add_list[i](ret, output)
-            else:
-                ret = output
-        return ret
-
-
 class GradAccumulationCell(Cell):
     """
     Wrap the network with Micro Batch to enable the grad accumulation in semi_auto_parallel/auto_parallel mode.

+    Note:
+        The api will be deprecated, please use the api :class:`mindspore.parallel.nn.GradAccumulation` instead.
+
     Args:
         network (Cell): The target network to wrap.
         micro_size (int): MicroBatch size.
@@ -736,12 +644,6 @@ class GradAccumulationCell(Cell):
         return ret


-def _pipeline_clear_grad(accu_grad, grad):
-    accu_grad = F.depend(accu_grad, grad)
-    zeros = F.zeros_like(accu_grad)
-    return F.assign(accu_grad, zeros)
-
-
 class _TrainGradAccuStepCell(TrainOneStepCell):
     """
     Wraps the network with an optimizer in pipeline mode.
@@ -753,6 +655,13 @@ class _TrainGradAccuStepCell(TrainOneStepCell):
         self.opt_shard = _get_enable_parallel_optimizer()
         self._get_attr_from_cell(network)
         self.enable_tft = False
+        if not self.sense_flag:
+            micro_size = 1.0
+            for _, cell in network.cells_and_names():
+                if hasattr(cell, 'micro_size'):
+                    micro_size = cell.micro_size
+                    break
+            self.sens = 1 / micro_size

     def construct(self, *inputs):
         if not self.sense_flag:
@@ -776,8 +685,10 @@ class _TrainGradAccuStepCell(TrainOneStepCell):
         grads = self.grad_no_sens(self.network, self.weights)(*inputs)
         accu_grads = ops.depend(self.accu_grads, grads)
         if self.opt_shard:
+            grads = self.hyper_map(F.partial(grad_scale, self.sens), grads)
             succ = self.optimizer(grads)
         else:
+            accu_grads = self.hyper_map(F.partial(grad_scale, self.sens), accu_grads)
             succ = self.optimizer(accu_grads)
         loss = ops.depend(loss, succ)
         clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
@@ -966,3 +877,151 @@ class _BroadCastCell(Cell):
         params = self.broadcast(params)
         new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
         return new_params
+
+
+class PipelineCell(Cell):
+    """
+    Slice MiniBatch into finer-grained MicroBatch for use in pipeline-parallel training.
+
+    Note:
+        - micro_size must be greater or equal to pipeline stages.
+        - The api will be deprecated, please use the api :class:`mindspore.parallel.nn.Pipeline` instead.
+
+    Args:
+        network (Cell): The target network to wrap.
+        micro_size (int): MicroBatch size.
+        stage_config (dict, optional): The stage configuration for each cell's execution in pipeline parallel.
+            Default ``None``.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import mindspore.nn as nn
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
+        >>> net = nn.PipelineCell(net, 4)
+    """
+    def __init__(self, network, micro_size, stage_config=None):
+        super(PipelineCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.micro_inputs = nn.CellList()
+        self.micro_size = micro_size
+        self.add_list = []
+        if not isinstance(network, Cell):
+            raise TypeError("For 'PipelineCell', the argument 'network' must cell type, "
+                            "but got the type : {}.".format(type(network)))
+        if not isinstance(micro_size, int):
+            raise TypeError("For 'PipelineCell', the argument 'micro_size' must be integer, "
+                            "but got the type : {}.".format(type(micro_size)))
+        if micro_size <= 0:
+            raise ValueError("For 'PipelineCell', the argument 'micro_size' must be large than 0, "
+                             "but got {}.".format(micro_size))
+        for i in range(micro_size):
+            micro_input = _MicroBatch(micro_size)
+            self.micro_inputs.append(micro_input)
+            self.add = P.Add().add_prim_attr("pipeline_end", i)
+            self.add_list.append(self.add)
+        self._get_attr_from_cell(network)
+
+        # prase stage_config
+        config_dict = {}
+        if stage_config is not None:
+            for cell_name, stage_num in stage_config.items():
+                config_cell_name = cell_name
+                config_stage_num = stage_num
+                config_dict[config_cell_name] = config_stage_num
+
+        # set cell.stage_config
+        for cell_name, cell in self.network.cells_and_names():
+            for config_cell_name, config_stage_num in config_dict.copy().items():
+                if not cell_name or not config_cell_name:
+                    continue
+                if cell_name == config_cell_name:
+                    setattr(cell, "pipeline_stage", config_stage_num)
+                    del config_dict[config_cell_name]
+
+        for config_cell_name, config_stage_num in config_dict.copy().items():
+            if str(network) == config_cell_name:
+                setattr(network, "pipeline_stage", config_stage_num)
+                del config_dict[config_cell_name]
+
+        # if there are any config elements left, print them
+        if config_dict:
+            for config_cell_name, config_stage_num in config_dict.items():
+                print("pipeline_cell stage_config set pipeline_stage fail!")
+                print("config cell name:" + str(config_cell_name) +
+                      " config stage num:" + str(config_stage_num))
+            print("network:" + str(self.network))
+            print("cell name available:")
+            for cell_name, cell in self.network.cells_and_names():
+                print(cell_name)
+            raise KeyError("For 'PipelineCell', the argument 'stage_config' : {} is not "
+                           "found in 'network' : {}".format(config_dict, network))
+
+    def construct(self, *inputs):
+        ret = None
+        for i in range(self.micro_size):
+            micro_input = self.micro_inputs[i](i, *inputs)
+            output = self.network(*micro_input)
+            if ret is not None:
+                ret = self.add_list[i](ret, output)
+            else:
+                ret = output
+        return ret
+
+
+class MicroBatchInterleaved(Cell):
+    """
+    This function splits the input at the 0th into interleave_num pieces and then performs
+    the computation of the wrapped cell. Application scenario: When there is model parallelism in semi-automatic mode
+    and network, if the first slice data is calculating forward, the second slice data will execute the
+    communication operators at the same time, to achieve the performance acceleration of communication and computing
+    concurrency.
+
+    Args:
+        network (Cell): The target network to wrap.
+        interleave_num (int, optional): split num of batch size. Default: ``2`` .
+
+    Inputs:
+        tuple[Tensor]. It's the same with the input of the `network` .
+
+    Outputs:
+        The wrapped input. The output of the input `network` should be a Tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import mindspore.nn as nn
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
+        >>> net = nn.MicroBatchInterleaved(net, 2)
+    """
+    def __init__(self, network, interleave_num=2):
+        super(MicroBatchInterleaved, self).__init__(auto_prefix=False)
+        if not isinstance(interleave_num, int):
+            raise TypeError("For 'MicroBatchInterleaved', the argument 'interleave_num' must be integer, "
+                            "but got the type : {}.".format(type(interleave_num)))
+        if interleave_num <= 0:
+            raise ValueError("For 'MicroBatchInterleaved', the argument 'interleave_num' must be large than 0, "
+                             "but got {}.".format(interleave_num))
+        self.network = network
+        self.interleave_num = interleave_num
+        self.interleave_inputs = nn.CellList()
+        self.add = P.Add().add_prim_attr("micro_interleaved_add_flag", True)
+        for _ in range(interleave_num):
+            interleave_data = _MicroBatch(interleave_num)
+            interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True)
+            interleave_data.strided_slice.add_prim_attr("interleave_num", interleave_num)
+            self.interleave_inputs.append(interleave_data)
+        self._get_attr_from_cell(network)
+
+    def construct(self, *inputs):
+        output = 0.0
+        for i in range(self.interleave_num):
+            interleave_input = self.interleave_inputs[i](i, *inputs)
+            output = self.add(output, self.network(*interleave_input))
+        return output
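The re-added `PipelineCell` above gains a `stage_config` argument that maps sub-cell names (as yielded by `network.cells_and_names()`) to pipeline stages, raising KeyError for any unmatched name. A minimal sketch of how it could be used; the two-layer network is illustrative only:

import mindspore.nn as nn

# Inside SequentialCell the sub-cells are registered as "0" and "1", which is what
# cells_and_names() reports and what the stage_config keys must match.
net = nn.SequentialCell(nn.Dense(16, 16), nn.Dense(16, 4))
pipeline_net = nn.PipelineCell(net, micro_size=4, stage_config={"0": 0, "1": 1})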
mindspore/nn/wrap/grad_reducer.py CHANGED

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from mindspore import context
 from mindspore import log as logger
 from mindspore.nn.cell import Cell
-from mindspore.nn.layer import Identity
 from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.common.sparse_tensor import RowTensorInner
 from mindspore.ops import functional as F, composite as C, operations as P
@@ -28,30 +27,13 @@ import mindspore.common.dtype as mstype
 from mindspore.common.sparse_tensor import Tensor
 from mindspore.common.api import jit
 from mindspore.common.parameter import Parameter
+from mindspore.nn.layer import Identity
 from mindspore.parallel._utils import _get_enable_parallel_optimizer

-
-grad_scale = C.MultitypeFuncGraph("grad_scale")
-shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
-reciprocal = P.Reciprocal()
+__all__ = ['DistributedGradReducer']


-
-def tensor_grad_scale_pipeline(scale, grad, accu_grad):
-    accu_grad = F.depend(accu_grad, grad)
-    new_grad = accu_grad * reciprocal(scale)
-    accu_grad = F.depend(accu_grad, new_grad)
-    zeros = F.tensor_mul(accu_grad, 0.0)
-    new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
-    return new_grad
-
-
-@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
-def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
-    new_grad = grad * reciprocal(scale)
-    accu_grad = F.depend(accu_grad, new_grad)
-    new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
-    return new_grad
+reduce_opt = C.MultitypeFuncGraph("reduce_opt")


 def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM_GROUP):
@@ -335,14 +317,14 @@ class DistributedGradReducer(Cell):

         For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
         Please see the `rank table Startup
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/tutorials/en/master/parallel/rank_table.html>`_
         for more details.

         For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/tutorials/en/master/parallel/mpirun.html>`_ .

         For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
-        Startup <https://www.mindspore.cn/
+        Startup <https://www.mindspore.cn/tutorials/en/master/parallel/dynamic_cluster.html>`_ .

         This example should be run with multiple devices.
@@ -427,7 +409,8 @@
         self.degree = degree
         self.degree = Tensor(1.0 / self.degree, mstype.float32)

-        self.allreduce_filter = tuple((x.layerwise_parallel is False) and
+        self.allreduce_filter = tuple((x.layerwise_parallel is False) and
+                                      (not x.param_info.is_in_pynative_shard) for x in parameters)
         is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
         split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
         if is_parallel_optimizer and split_indices:
@@ -447,7 +430,7 @@
         self.mode = context.get_context("mode")
         self.enable_tuple_broaden = True

-    @jit
+    @jit(backend="ms_backend")
     def construct(self, grads):
         """
         Under certain circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
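The decorator change above pins `DistributedGradReducer.construct` to the `ms_backend` compile backend. A minimal sketch of the same keyword on a user function, assuming the public `mindspore.jit` accepts the `backend` argument used internally here:

import mindspore as ms

@ms.jit(backend="ms_backend")   # keyword taken from the internal usage shown above
def scale_and_sum(x, y):
    return x * 2 + y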
@@ -488,13 +471,39 @@
             raise RuntimeError("{} can not use DistributedGradReducer in graph mode".format(parallel_mode))


+grad_scale = C.MultitypeFuncGraph("grad_scale")
+shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
+reciprocal = P.Reciprocal()
+
+
+@grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_grad_scale_pipeline(scale, grad, accu_grad):
+    accu_grad = F.depend(accu_grad, grad)
+    new_grad = accu_grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    zeros = F.tensor_mul(accu_grad, 0.0)
+    new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
+    return new_grad
+
+
+@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
+    new_grad = grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
+    return new_grad
+
+
 class PipelineGradReducer(Cell):
     """
     PipelineGradReducer is a gradient reducer for pipeline parallelism.

+    Note:
+        The api will be deprecated, please use the api :class:`mindspore.parallel.nn.PipelineGradReducer` instead.
+
     Args:
         parameters (list): the parameters to be updated.
-        scale_sense (float): the scale sense of the gradient. Default: 1.0
+        scale_sense (float, optional): the scale sense of the gradient. Default: ``1.0``.

     Raise:
         RuntimeError: If the mode is not graph mode.
@@ -509,11 +518,11 @@ class PipelineGradReducer(Cell):

         For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
         Please see the `rank table Startup
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/tutorials/en/master/parallel/rank_table.html>`_
         for more details.

         For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/tutorials/en/master/parallel/mpirun.html>`_ .

         This example should be run with multiple devices.
@@ -554,7 +563,7 @@
         >>> net.layer3.pipeline_stage = 1
         >>> loss_fn = nn.CrossEntropyLoss()
         >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
-        >>> net_with_loss = nn.
+        >>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 2)
         >>> net_with_loss.set_train()
         >>> def forward_fn(inputs, target):
         ...     loss = net_with_loss(inputs, target)
@@ -576,7 +585,7 @@
         >>> print(loss)
         46.36721
     """
-    def __init__(self, parameters, scale_sense=1.0):
+    def __init__(self, parameters, scale_sense=1.0, opt_shard=None):
         super(PipelineGradReducer, self).__init__(auto_prefix=False)
         self._check_mode()
         self.accu_grads = parameters.clone(prefix="accu_grads", init="zeros")
@@ -584,7 +593,10 @@
         self.degree = Tensor(1, mstype.float32)
         self.scale_sense = Parameter(scale_sense, name='scale_sense')
         self.hyper_map = C.HyperMap()
-
+        if opt_shard is None:
+            self.opt_shard = _get_enable_parallel_optimizer()
+        else:
+            self.opt_shard = opt_shard

     @jit
     def construct(self, grads):
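A sketch of the widened constructor shown above: passing `opt_shard` explicitly bypasses the `_get_enable_parallel_optimizer()` autodetection, and the rest of the pipeline training setup is omitted here.

import mindspore.nn as nn

# `optimizer` stands in for an existing MindSpore optimizer (placeholder); its
# `parameters` attribute is the ParameterTuple the reducer clones into accu_grads.
grad_reducer = nn.PipelineGradReducer(optimizer.parameters, scale_sense=1.0, opt_shard=True)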
@@ -603,6 +615,3 @@
         mode = context.get_context('mode')
         if mode != context.GRAPH_MODE:
             raise RuntimeError(f"PipelineGradReducer only support graph mode, but get {mode}")
-        parallel_mode = context.get_auto_parallel_context('parallel_mode')
-        if parallel_mode not in (context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL):
-            raise RuntimeError(f"{parallel_mode} can not use PipelineGradReducer in graph mode")
mindspore/nn/wrap/loss_scale.py CHANGED
@@ -31,7 +31,6 @@ from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.ops.operations.nn_ops import AllFinite
 from mindspore.common import dtype as mstype
-from mindspore.common.api import jit
 from mindspore._c_expression import MSContext
 from mindspore.run_check._check_version import AscendEnvChecker
 from mindspore import log as logger
@@ -93,8 +92,8 @@ class DynamicLossScaleUpdateCell(Cell):
     Dynamic Loss scale update cell.

     For loss scaling training, the initial loss scaling value will be set to be `loss_scale_value`.
-    In each training step, the loss scaling value will be decreased by `
-    when there is an overflow. And it will be increased by
+    In each training step, the loss scaling value will be decreased by :math:`loss\_scale/scale\_factor`
+    when there is an overflow. And it will be increased by :math:`loss\_scale * scale\_factor` if there is no
     overflow for a continuous `scale_window` steps.

     `get_update_cell` method of :class:`mindspore.amp.DynamicLossScaleManager` will return this class. It will be called
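A plain-Python sketch of the update rule the corrected docstring now states (shrink the scale on overflow, grow it after `scale_window` consecutive clean steps); this mirrors the documented formula only and is not the cell's actual implementation.

def update_loss_scale(loss_scale, overflow, good_steps, scale_factor=2.0, scale_window=1000):
    """Documented rule: loss_scale/scale_factor on overflow, loss_scale*scale_factor after scale_window clean steps."""
    if overflow:
        return loss_scale / scale_factor, 0          # decreased when an overflow occurs
    good_steps += 1
    if good_steps >= scale_window:
        return loss_scale * scale_factor, 0          # increased after scale_window steps without overflow
    return loss_scale, good_steps

# e.g. starting from 65536 with scale_factor=2: one overflow -> 32768; 1000 clean steps -> 65536 again.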
@@ -377,7 +376,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         self.loss_scaling_manager = None
         self._ascend_check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE')

-        self.enable_allfinite =
+        self.enable_allfinite = True
         runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF')
         global_jit_config = context.get_jit_config()
         if runtime_conf is not None and ("all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf):
@@ -389,7 +388,8 @@
         elif global_jit_config:
             logger.debug("Current global jit config is: {}".format(global_jit_config["jit_level"]))
             self.enable_allfinite = global_jit_config["jit_level"] == "O0" or global_jit_config["jit_level"] == "O1"
-
+        if "RANK_TABLE_FILE" in os.environ:
+            self.enable_allfinite = False
         if self.ascend_910b_target:
             checker = AscendEnvChecker(None)
             if not checker.check_custom_version():
@@ -506,7 +506,7 @@
             overflow = AllFinite()(compute_output)

         if self.is_distributed:
-            overflow = P.Cast()(overflow, mstype.
+            overflow = P.Cast()(overflow, mstype.float32)
             overflow = P.Cast()(self.allreduce(overflow), mstype.bool_)
         return overflow
@@ -548,7 +548,6 @@
         overflow = self.logic_not(overall_finite)
         return overflow

-    @jit
     def get_overflow_status(self, status, compute_output):
         """
         Get floating-point overflow status.