mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
|
@@ -29,7 +29,7 @@ from mindspore.common.api import jit
|
|
|
29
29
|
from mindspore.common.tensor import Tensor
|
|
30
30
|
from mindspore.common._register_for_tensor import Registry
|
|
31
31
|
from mindspore._c_expression import MetaFuncGraph_, function_id
|
|
32
|
-
from mindspore._c_expression import
|
|
32
|
+
from mindspore._c_expression import TensorPy as Tensor_
|
|
33
33
|
from mindspore._extends.parse.resources import convert_object_map
|
|
34
34
|
from mindspore import _checkparam as validator
|
|
35
35
|
from mindspore import Parameter, ParameterTuple
|
|
@@ -49,11 +49,12 @@ from mindspore.train.data_sink import _init_sink_dataset
|
|
|
49
49
|
from mindspore.train.summary import SummaryRecord
|
|
50
50
|
from mindspore.train._utils import _exec_datagraph
|
|
51
51
|
from mindspore.train.summary.writer import BaseWriter
|
|
52
|
-
from mindspore.train.serialization import _exec_save, load, export_split_mindir,
|
|
52
|
+
from mindspore.train.serialization import _exec_save, load, export_split_mindir, _parse_ckpt_proto, \
|
|
53
53
|
_generate_front_info_for_param_data_file, _get_data_file, _encrypt_data, _split_save, _save_mindir_together, \
|
|
54
54
|
_load_into_param_dict
|
|
55
55
|
from mindspore.parallel import _cost_model_context
|
|
56
56
|
from mindspore.parallel._offload_context import offload_context
|
|
57
|
+
from mindspore.parallel._utils import _is_in_data_parallel_mode
|
|
57
58
|
from mindspore.run_check._check_version import check_version_and_env_config
|
|
58
59
|
from mindspore.dataset.callback.ds_callback import DSCallback, WaitedDSCallback
|
|
59
60
|
from mindspore.dataset.transforms.c_transforms import TensorOperation as CTensorOperation, OneHot as COneHot, \
|
|
@@ -127,7 +128,7 @@ from mindspore.dataset.vision.transforms import AdjustBrightness, AdjustContrast
|
|
|
127
128
|
RandomVerticalFlipWithBBox as VRandomVerticalFlipWithBBox, Rescale as VRescale, Resize as VResize, ResizedCrop, \
|
|
128
129
|
ResizeWithBBox as VResizeWithBBox, Rotate as VRotate, SlicePatches as VSlicePatches, Solarize, ToTensor,\
|
|
129
130
|
TrivialAugmentWide, UniformAugment as VUniformAugment, VerticalFlip as VVerticalFlip
|
|
130
|
-
from mindspore.profiler.
|
|
131
|
+
from mindspore.profiler.profiler import Profiler
|
|
131
132
|
from mindspore.communication._hccl_management import get_rank_size, get_rank_id
|
|
132
133
|
from mindspore.communication._comm_helper import _create_group_helper, _destroy_group_helper
|
|
133
134
|
from mindspore.communication.management import _set_rank_from_mpi, init as cinit, release as crelease
|
|
@@ -360,6 +361,7 @@ FUNC_KEY_DICT_ITEMS = 22 # dict.items
|
|
|
360
361
|
FUNC_KEY_PRIMITIVE_ASSIGN = 23 # mindspore.ops.assign, Primitive("Assign")
|
|
361
362
|
FUNC_KEY_TENSOR_SETITEM = 24 # Tensor.__setitem__
|
|
362
363
|
FUNC_KEY_TENSOR_ASSIGN_VALUE = 25 # Tensor.assign_value
|
|
364
|
+
FUNC_KEY_TENSOR_IS_CONTIGUOUS = 26 # Tensor.is_contiguous
|
|
363
365
|
|
|
364
366
|
# Initialized only once. This map will initialize by c++ when start pijit.
|
|
365
367
|
# key is customer if fuzzy match. (Primitive, constexpr, primexpr, MetaFuncGraph)
|
|
@@ -376,19 +378,19 @@ _func_map = {
|
|
|
376
378
|
constexpr_key: FUNC_KEY_CONSTEXPR,
|
|
377
379
|
primexpr_key: FUNC_KEY_PRIMEXPR,
|
|
378
380
|
meta_func_graph_key: FUNC_KEY_META_FUNCG_RAPH,
|
|
379
|
-
|
|
381
|
+
function_id(GraphCell.__call__): FUNC_KEY_GRAPH_CELL,
|
|
380
382
|
id(psjit_code): FUNC_KEY_PSJIT_CODE,
|
|
381
|
-
|
|
382
|
-
|
|
383
|
+
function_id(_get_cache_prim): FUNC_KEY_GET_CACHE_PRIM,
|
|
384
|
+
function_id(Registry.get): FUNC_KEY_REGISTRY_GET,
|
|
383
385
|
|
|
384
386
|
# tensor side-effect
|
|
385
387
|
primitive_assign_key: FUNC_KEY_PRIMITIVE_ASSIGN,
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
388
|
+
function_id(F.assign): FUNC_KEY_PRIMITIVE_ASSIGN,
|
|
389
|
+
function_id(Tensor.assign_value): FUNC_KEY_TENSOR_ASSIGN_VALUE,
|
|
390
|
+
function_id(Tensor.__setitem__): FUNC_KEY_TENSOR_SETITEM,
|
|
389
391
|
|
|
390
392
|
# Tensor method
|
|
391
|
-
|
|
393
|
+
function_id(Tensor.astype): FUNC_KEY_TENSOR_ASTYPE,
|
|
392
394
|
|
|
393
395
|
# types.BuiltinFunctionType
|
|
394
396
|
function_id(isinstance): FUNC_KEY_BUILTIN_FUNC,
|
|
@@ -448,6 +450,7 @@ _func_map = {
|
|
|
448
450
|
function_id(str.isalnum): FUNC_KEY_BUILTIN_FUNC,
|
|
449
451
|
function_id(str.isidentifier): FUNC_KEY_BUILTIN_FUNC,
|
|
450
452
|
function_id(str.isprintable): FUNC_KEY_BUILTIN_FUNC,
|
|
453
|
+
function_id(str.replace): FUNC_KEY_BUILTIN_FUNC,
|
|
451
454
|
function_id(str.format): FUNC_KEY_BUILTIN_FUNC,
|
|
452
455
|
function_id(str.format_map): FUNC_KEY_BUILTIN_FUNC,
|
|
453
456
|
function_id(str.__format__): FUNC_KEY_BUILTIN_FUNC,
|
|
@@ -472,7 +475,7 @@ _func_map = {
|
|
|
472
475
|
function_id(Tensor_.getitem_index_info): FUNC_KEY_BUILTIN_FUNC,
|
|
473
476
|
function_id(Tensor_.get_bytes): FUNC_KEY_BUILTIN_FUNC,
|
|
474
477
|
function_id(Tensor_.is_init): FUNC_KEY_BUILTIN_FUNC,
|
|
475
|
-
function_id(Tensor_.is_contiguous):
|
|
478
|
+
function_id(Tensor_.is_contiguous): FUNC_KEY_TENSOR_IS_CONTIGUOUS,
|
|
476
479
|
function_id(Tensor_.stride): FUNC_KEY_BUILTIN_FUNC,
|
|
477
480
|
# Tensor_.asnumpy need real tensor value
|
|
478
481
|
|
|
@@ -488,6 +491,7 @@ _func_map = {
|
|
|
488
491
|
function_id(validator.check_number_range): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
489
492
|
function_id(validator.check_is_int): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
490
493
|
function_id(validator.check_is_number): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
494
|
+
function_id(validator.check_positive_int_sequence): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
491
495
|
function_id(np_version_valid): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
492
496
|
function_id(_is_initialized): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
493
497
|
function_id(_set_elegant_exit_handle): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
@@ -496,7 +500,9 @@ _func_map = {
|
|
|
496
500
|
function_id(get_rank_size): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
497
501
|
function_id(get_rank_id): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
498
502
|
function_id(offload_context): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
503
|
+
function_id(_is_in_data_parallel_mode): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
499
504
|
function_id(check_version_and_env_config): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
505
|
+
function_id(Tensor.tolist): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
500
506
|
|
|
501
507
|
# inner function
|
|
502
508
|
function_id(type_size_in_bytes): FUNC_KEY_BUILTIN_FUNC,
|
|
@@ -530,7 +536,6 @@ _func_map = {
|
|
|
530
536
|
function_id(_exec_save): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
531
537
|
function_id(load): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
532
538
|
function_id(export_split_mindir): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
533
|
-
function_id(obfuscate_model): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
534
539
|
function_id(_parse_ckpt_proto): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
535
540
|
function_id(_generate_front_info_for_param_data_file): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
536
541
|
function_id(_get_data_file): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# Copyright 2025 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Store and get tensor method"""
|
|
16
|
+
from mindspore import Tensor
|
|
17
|
+
from mindspore._c_expression import function_id
|
|
18
|
+
|
|
19
|
+
tensor_method_id_to_name = {}
|
|
20
|
+
for method_name in dir(Tensor):
|
|
21
|
+
method_id = function_id(getattr(Tensor, method_name))
|
|
22
|
+
tensor_method_id_to_name[method_id] = method_name
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_tensor_method_name(id):
|
|
26
|
+
"""Get method name by function id"""
|
|
27
|
+
return tensor_method_id_to_name.get(id, None)
|
mindspore/_extends/utils.py
CHANGED
mindspore/amp.py
CHANGED
|
@@ -69,6 +69,12 @@ def _enable_all_finite():
|
|
|
69
69
|
if not checker.check_custom_version():
|
|
70
70
|
logger.debug("Disable AllFinite due to version check failure.")
|
|
71
71
|
return False
|
|
72
|
+
else:
|
|
73
|
+
return False
|
|
74
|
+
|
|
75
|
+
if "RANK_TABLE_FILE" in os.environ:
|
|
76
|
+
return False
|
|
77
|
+
|
|
72
78
|
runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF')
|
|
73
79
|
global_jit_config = context.get_jit_config()
|
|
74
80
|
if runtime_conf is not None and ("all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf):
|
|
@@ -82,7 +88,7 @@ def _enable_all_finite():
|
|
|
82
88
|
if global_jit_config:
|
|
83
89
|
logger.debug("Current global jit config is: {}".format(global_jit_config["jit_level"]))
|
|
84
90
|
return global_jit_config["jit_level"] == "O0" or global_jit_config["jit_level"] == "O1"
|
|
85
|
-
return
|
|
91
|
+
return True
|
|
86
92
|
|
|
87
93
|
|
|
88
94
|
def _grad_unscale(scale, grad):
|
|
@@ -93,12 +99,12 @@ def _grad_scale(scale, grad):
|
|
|
93
99
|
return grad * scale.astype(grad.dtype)
|
|
94
100
|
|
|
95
101
|
|
|
96
|
-
@jit
|
|
102
|
+
@jit(backend="ms_backend")
|
|
97
103
|
def _grad_scale_map(scale_value, inputs):
|
|
98
104
|
return _hypermap(_partial(_grad_scale, scale_value), inputs)
|
|
99
105
|
|
|
100
106
|
|
|
101
|
-
@jit
|
|
107
|
+
@jit(backend="ms_backend")
|
|
102
108
|
def _grad_unscale_map(scale_value, inputs):
|
|
103
109
|
return _hypermap(_partial(_grad_unscale, scale_value), inputs)
|
|
104
110
|
|
|
@@ -110,7 +116,7 @@ def _overflow(inputs):
|
|
|
110
116
|
return 1 - status.all()
|
|
111
117
|
|
|
112
118
|
|
|
113
|
-
@jit
|
|
119
|
+
@jit(backend="ms_backend")
|
|
114
120
|
def _all_finite(inputs, check_overflow_mode, enable_allfinite):
|
|
115
121
|
"""all finite check"""
|
|
116
122
|
if _ascend_target():
|
|
@@ -319,7 +325,7 @@ class StaticLossScaler(LossScaler):
|
|
|
319
325
|
|
|
320
326
|
class DynamicLossScaler(LossScaler):
|
|
321
327
|
r"""
|
|
322
|
-
|
|
328
|
+
Manager for dynamically adjusting the loss scaling factor.
|
|
323
329
|
|
|
324
330
|
Dynamic loss scaling tries to determine the largest loss scale value that
|
|
325
331
|
will keep gradients finite. It does this by increasing the loss scale every
|
mindspore/atlprov.dll
CHANGED
|
Binary file
|
mindspore/avcodec-59.dll
CHANGED
|
Binary file
|
mindspore/avdevice-59.dll
CHANGED
|
Binary file
|
mindspore/avfilter-8.dll
CHANGED
|
Binary file
|
mindspore/avformat-59.dll
CHANGED
|
Binary file
|
mindspore/avutil-57.dll
CHANGED
|
Binary file
|
mindspore/boost/__init__.py
CHANGED
|
@@ -13,8 +13,8 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""
|
|
16
|
-
Boost
|
|
17
|
-
|
|
16
|
+
Boost is able to automatically optimize network performance, e.g., by reducing BN, gradient freezing,
|
|
17
|
+
and accumulating gradients to achieve network acceleration.
|
|
18
18
|
|
|
19
19
|
Note:
|
|
20
20
|
This feature is a beta feature, and we are still improving its functionality.
|
mindspore/boost/base.py
CHANGED
|
@@ -21,15 +21,12 @@ import math
|
|
|
21
21
|
import copy
|
|
22
22
|
import numpy as np
|
|
23
23
|
from scipy import linalg as la
|
|
24
|
-
from mindspore.context import ParallelMode
|
|
25
24
|
import mindspore.nn as nn
|
|
26
25
|
from mindspore.nn.optim import LARS
|
|
27
26
|
from mindspore import log as logger
|
|
28
27
|
from mindspore.common import Parameter
|
|
29
|
-
from mindspore.communication.management import get_group_size
|
|
28
|
+
from mindspore.communication.management import get_rank, get_group_size
|
|
30
29
|
from mindspore.train.serialization import load_checkpoint
|
|
31
|
-
from mindspore.parallel._utils import _get_global_rank
|
|
32
|
-
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
|
33
30
|
from mindspore.boost.less_batch_normalization import CommonHeadLastFN
|
|
34
31
|
|
|
35
32
|
|
|
@@ -329,7 +326,7 @@ def _get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_n
|
|
|
329
326
|
if os.path.exists(save_pca_end_path):
|
|
330
327
|
os.remove(save_pca_end_path)
|
|
331
328
|
|
|
332
|
-
rank =
|
|
329
|
+
rank = get_rank()
|
|
333
330
|
local_pca_mat_path = full_pca_mat_path[:-4] + "_rank_" + str(rank) + ".npy"
|
|
334
331
|
if os.path.exists(local_pca_mat_path):
|
|
335
332
|
os.remove(local_pca_mat_path)
|
|
@@ -498,8 +495,7 @@ def _save_local_pca_mat(pca_mat, full_pca_mat_path, n_component):
|
|
|
498
495
|
full_pca_mat_path (str): the path of full pca mat.
|
|
499
496
|
n_component (int): pca component.
|
|
500
497
|
"""
|
|
501
|
-
|
|
502
|
-
rank_size = 1 if parallel_mode == ParallelMode.STAND_ALONE else get_group_size()
|
|
498
|
+
rank_size = get_group_size()
|
|
503
499
|
local_dim = math.ceil(n_component // rank_size)
|
|
504
500
|
for rank_id in range(rank_size):
|
|
505
501
|
start_index = rank_id * local_dim
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2021-
|
|
1
|
+
# Copyright 2021-2025 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,12 +15,13 @@
|
|
|
15
15
|
"""Boost Mode Cell Wrapper."""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
|
+
import os
|
|
18
19
|
import numpy as np
|
|
19
20
|
from mindspore.nn.wrap import TrainOneStepCell
|
|
20
21
|
import mindspore.context as context
|
|
21
22
|
from mindspore.context import ParallelMode
|
|
22
23
|
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_gradients_mean
|
|
23
|
-
from mindspore.communication.management import get_group_size, create_group
|
|
24
|
+
from mindspore.communication.management import get_rank, get_group_size, create_group
|
|
24
25
|
from mindspore.nn.cell import Cell
|
|
25
26
|
from mindspore.nn import SequentialCell
|
|
26
27
|
from mindspore.common import Tensor
|
|
@@ -38,6 +39,10 @@ from mindspore.boost.adasum import AdaSum
|
|
|
38
39
|
from mindspore.boost.dim_reduce import DimReduce
|
|
39
40
|
from mindspore.boost.grad_accumulation import gradient_accumulation_op, gradient_clear_op
|
|
40
41
|
from mindspore.boost.base import _load_local_pca_mat
|
|
42
|
+
from mindspore.ops.operations.nn_ops import AllFinite
|
|
43
|
+
from mindspore._c_expression import MSContext
|
|
44
|
+
from mindspore.run_check._check_version import AscendEnvChecker
|
|
45
|
+
from mindspore import log as logger
|
|
41
46
|
|
|
42
47
|
__all__ = ["BoostTrainOneStepCell", "BoostTrainOneStepWithLossScaleCell"]
|
|
43
48
|
|
|
@@ -90,6 +95,27 @@ def _tensor_grad_overflow(grad):
|
|
|
90
95
|
def _tensor_grad_overflow_row_tensor(grad):
|
|
91
96
|
return grad_overflow(grad.values)
|
|
92
97
|
|
|
98
|
+
_ascend_grad_overflow = C.MultitypeFuncGraph("_ascend_grad_overflow")
|
|
99
|
+
ascend_grad_overflow = P.IsFinite()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@_ascend_grad_overflow.register("Tensor")
|
|
103
|
+
def _tensor_ascend_grad_overflow(grad):
|
|
104
|
+
status = ascend_grad_overflow(grad)
|
|
105
|
+
base = Tensor(1.0, dtype=mstype.float32)
|
|
106
|
+
output = base - status.all()
|
|
107
|
+
output = P.Reshape()(output, ((-1,)))
|
|
108
|
+
return output
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@_ascend_grad_overflow.register("RowTensor")
|
|
112
|
+
def _tensor_ascend_grad_overflow_row_tensor(grad):
|
|
113
|
+
status = ascend_grad_overflow(grad.values)
|
|
114
|
+
base = Tensor(1.0, dtype=mstype.float32)
|
|
115
|
+
output = base - status.all()
|
|
116
|
+
output = P.Reshape()(output, ((1,)))
|
|
117
|
+
return output
|
|
118
|
+
|
|
93
119
|
|
|
94
120
|
class _OutputToFloat16(Cell):
|
|
95
121
|
"Wrap cell for amp. Cast network output back to float16"
|
|
@@ -362,7 +388,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
|
|
|
362
388
|
gamma = self.auto_boost.gamma
|
|
363
389
|
alpha = self.auto_boost.alpha
|
|
364
390
|
sigma = self.auto_boost.sigma
|
|
365
|
-
_rank =
|
|
391
|
+
_rank = get_rank()
|
|
366
392
|
_rank_size = 1 if self.parallel_mode == ParallelMode.STAND_ALONE else get_group_size()
|
|
367
393
|
n_components = self.auto_boost.n_components
|
|
368
394
|
timeout = self.auto_boost.timeout
|
|
@@ -483,7 +509,11 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
|
|
|
483
509
|
self.allreduce = P.AllReduce()
|
|
484
510
|
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
|
485
511
|
self.gpu_target = (context.get_context("device_target") == "GPU")
|
|
512
|
+
self.ascend_910a_target = (MSContext.get_instance().get_ascend_soc_version() == 'ascend910')
|
|
513
|
+
self.ascend_910b_target = (MSContext.get_instance().get_ascend_soc_version() in ['ascend910b', 'ascend910_93'])
|
|
486
514
|
self.loss_scaling_manager = None
|
|
515
|
+
self._ascend_check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE')
|
|
516
|
+
|
|
487
517
|
self.base0 = Tensor(0, mstype.int32)
|
|
488
518
|
self.reduce_all = P.ReduceAll(keep_dims=False)
|
|
489
519
|
self.logic_not = P.LogicalNot()
|
|
@@ -512,6 +542,26 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
|
|
|
512
542
|
else:
|
|
513
543
|
raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
|
|
514
544
|
|
|
545
|
+
self.enable_allfinite = True
|
|
546
|
+
runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF')
|
|
547
|
+
global_jit_config = context.get_jit_config()
|
|
548
|
+
if runtime_conf is not None and ("all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf):
|
|
549
|
+
logger.debug("Enable AllFinite through the environment variable MS_DEV_RUNTIME_CONF.")
|
|
550
|
+
self.enable_allfinite = True
|
|
551
|
+
elif runtime_conf is not None and ("all_finite:False" in runtime_conf or "all_finite:false" in runtime_conf):
|
|
552
|
+
logger.debug("Disable AllFinite through the environment variable MS_DEV_RUNTIME_CONF.")
|
|
553
|
+
self.enable_allfinite = False
|
|
554
|
+
elif global_jit_config:
|
|
555
|
+
logger.debug("Current global jit config is: {}".format(global_jit_config["jit_level"]))
|
|
556
|
+
self.enable_allfinite = global_jit_config["jit_level"] == "O0" or global_jit_config["jit_level"] == "O1"
|
|
557
|
+
if "RANK_TABLE_FILE" in os.environ:
|
|
558
|
+
self.enable_allfinite = False
|
|
559
|
+
if self.ascend_910b_target:
|
|
560
|
+
checker = AscendEnvChecker(None)
|
|
561
|
+
if not checker.check_custom_version():
|
|
562
|
+
logger.debug("Disable AllFinite due to version check failure.")
|
|
563
|
+
self.enable_allfinite = False
|
|
564
|
+
|
|
515
565
|
def construct(self, *inputs):
|
|
516
566
|
weights = self.weights
|
|
517
567
|
loss = self.network(*inputs)
|
|
@@ -523,7 +573,7 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
|
|
|
523
573
|
cond, scaling_sens = self._enhanced_amp_process_overflow_status(grads)
|
|
524
574
|
else:
|
|
525
575
|
scaling_sens = self.scale_sense
|
|
526
|
-
status
|
|
576
|
+
status = Tensor([0] * 8, mstype.int32)
|
|
527
577
|
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
|
|
528
578
|
|
|
529
579
|
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
|
|
@@ -646,54 +696,99 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
|
|
|
646
696
|
compute_input = F.depend(compute_input, clear_status)
|
|
647
697
|
return status, compute_input
|
|
648
698
|
|
|
699
|
+
def _check_overflow_status_on_infnan_mode(self, grad_overflow_check_func, compute_output):
|
|
700
|
+
"""check overflow status on infnan mode."""
|
|
701
|
+
flag_sum = self.hyper_map(F.partial(grad_overflow_check_func), compute_output)
|
|
702
|
+
flag_sum = P.AddN()(flag_sum)
|
|
703
|
+
# convert flag_sum to scalar
|
|
704
|
+
flag_sum = P.Reshape()(flag_sum, (()))
|
|
705
|
+
return flag_sum
|
|
706
|
+
|
|
707
|
+
def _get_distributed_overflow_status_on_infnan_mode(self, grad_overflow_check_func, compute_output):
|
|
708
|
+
"""converge the distributed overflow status on infnan mode."""
|
|
709
|
+
flag_sum = self._check_overflow_status_on_infnan_mode(grad_overflow_check_func, compute_output)
|
|
710
|
+
|
|
711
|
+
if self.is_distributed:
|
|
712
|
+
# sum overflow flag over devices
|
|
713
|
+
flag_reduce = self.allreduce(flag_sum)
|
|
714
|
+
overflow = self.less_equal(self.base, flag_reduce)
|
|
715
|
+
else:
|
|
716
|
+
overflow = self.less_equal(self.base, flag_sum)
|
|
717
|
+
return overflow
|
|
718
|
+
|
|
719
|
+
def _get_distributed_overflow_status_on_infnan_enable_allfinite(self, compute_output):
|
|
720
|
+
"""check overflow status on infnan kernel mode."""
|
|
721
|
+
overflow = AllFinite()(compute_output)
|
|
722
|
+
|
|
723
|
+
if self.is_distributed:
|
|
724
|
+
overflow = P.Cast()(overflow, mstype.int8)
|
|
725
|
+
overflow = P.Cast()(self.allreduce(overflow), mstype.bool_)
|
|
726
|
+
return overflow
|
|
727
|
+
|
|
728
|
+
def _get_gpu_overflow_status(self, compute_output):
|
|
729
|
+
"""get overflow status of gpu."""
|
|
730
|
+
overflow = self._get_distributed_overflow_status_on_infnan_mode(_grad_overflow, compute_output)
|
|
731
|
+
return overflow
|
|
732
|
+
|
|
733
|
+
def _get_ascend_overflow_status_on_infnan_mode(self, compute_output):
|
|
734
|
+
"""get overflow status of ascend on infnan mode."""
|
|
735
|
+
overflow = False
|
|
736
|
+
if self.enable_allfinite:
|
|
737
|
+
overflow = self._get_distributed_overflow_status_on_infnan_enable_allfinite(compute_output)
|
|
738
|
+
else:
|
|
739
|
+
overflow = self._get_distributed_overflow_status_on_infnan_mode(_ascend_grad_overflow, compute_output)
|
|
740
|
+
return overflow
|
|
741
|
+
|
|
742
|
+
def _get_ascend_overflow_status_on_saturation_mode(self, status, compute_output):
|
|
743
|
+
"""get overflow status of ascend on saturation mode"""
|
|
744
|
+
status = F.depend(status, compute_output)
|
|
745
|
+
get_status = NPUGetFloatStatusV2()(status)
|
|
746
|
+
|
|
747
|
+
if self.is_distributed:
|
|
748
|
+
# sum overflow flag over devices
|
|
749
|
+
flag_reduce = self.allreduce(get_status)
|
|
750
|
+
# get_status not equal to [0]*8 means overflow
|
|
751
|
+
flag = self.equal(self.base0, flag_reduce)
|
|
752
|
+
status = F.depend(status, flag)
|
|
753
|
+
# distributed needs to skip allreduce to avoid its overflow affecting the next step
|
|
754
|
+
clear_status = NPUClearFloatStatusV2()(status)
|
|
755
|
+
flag = F.depend(flag, clear_status)
|
|
756
|
+
overall_finite = self.reduce_all(flag)
|
|
757
|
+
else:
|
|
758
|
+
status = F.depend(status, get_status)
|
|
759
|
+
clear_status = NPUClearFloatStatusV2()(status)
|
|
760
|
+
get_status = F.depend(get_status, clear_status)
|
|
761
|
+
flag = self.equal(self.base0, get_status)
|
|
762
|
+
overall_finite = self.reduce_all(flag)
|
|
763
|
+
overflow = self.logic_not(overall_finite)
|
|
764
|
+
return overflow
|
|
765
|
+
|
|
766
|
+
|
|
649
767
|
def _get_overflow_status(self, status, compute_output):
|
|
650
768
|
"""
|
|
651
769
|
Get floating-point overflow status.
|
|
652
770
|
|
|
653
|
-
Get overflow results after executing the target process for overflow detection.
|
|
771
|
+
Get overflow results after executing the target process for overflow detection. User-defined training network
|
|
772
|
+
based on this class can also call this interface to process the overflow.
|
|
654
773
|
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
computation.
|
|
774
|
+
Args:
|
|
775
|
+
status (object): To control the execution sequence with start_overflow_check, it should be set as the first
|
|
776
|
+
output of start_overflow_check.
|
|
777
|
+
compute_output: Overflow detection should be performed in a certain computation process. Set
|
|
778
|
+
`compute_output` as the output of the computation process.
|
|
660
779
|
|
|
661
|
-
|
|
780
|
+
Returns:
|
|
662
781
|
bool, whether the overflow occurs or not.
|
|
663
782
|
"""
|
|
664
|
-
if
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
# sum overflow flag over devices
|
|
670
|
-
flag_reduce = self.allreduce(get_status)
|
|
671
|
-
# get_status not equal to [0]*8 means overflow
|
|
672
|
-
flag = self.equal(self.base0, flag_reduce)
|
|
673
|
-
status = F.depend(status, flag)
|
|
674
|
-
# distributed needs to skip allreduce to avoid its overflow affecting the next step
|
|
675
|
-
clear_status = NPUClearFloatStatusV2()(status)
|
|
676
|
-
flag = F.depend(flag, clear_status)
|
|
677
|
-
overall_finite = self.reduce_all(flag)
|
|
678
|
-
else:
|
|
679
|
-
status = F.depend(status, get_status)
|
|
680
|
-
clear_status = NPUClearFloatStatusV2()(status)
|
|
681
|
-
get_status = F.depend(get_status, clear_status)
|
|
682
|
-
flag = self.equal(self.base0, get_status)
|
|
683
|
-
overall_finite = self.reduce_all(flag)
|
|
684
|
-
overflow = self.logic_not(overall_finite)
|
|
685
|
-
else:
|
|
686
|
-
flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
|
|
687
|
-
flag_sum = P.AddN()(flag_sum)
|
|
688
|
-
# convert flag_sum to scalar
|
|
689
|
-
flag_sum = P.Reshape()(flag_sum, (()))
|
|
690
|
-
|
|
691
|
-
if self.is_distributed:
|
|
692
|
-
# sum overflow flag over devices
|
|
693
|
-
flag_reduce = self.allreduce(flag_sum)
|
|
694
|
-
overflow = self.less_equal(self.base, flag_reduce)
|
|
783
|
+
if self.gpu_target:
|
|
784
|
+
overflow = self._get_gpu_overflow_status(compute_output)
|
|
785
|
+
elif self.ascend_910b_target:
|
|
786
|
+
if self._ascend_check_overflow_mode == "SATURATION_MODE":
|
|
787
|
+
overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output)
|
|
695
788
|
else:
|
|
696
|
-
overflow = self.
|
|
789
|
+
overflow = self._get_ascend_overflow_status_on_infnan_mode(compute_output)
|
|
790
|
+
else: # ascend_910a_target
|
|
791
|
+
overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output)
|
|
697
792
|
return overflow
|
|
698
793
|
|
|
699
794
|
def _process_loss_scale(self, overflow):
|
mindspore/c1.dll
CHANGED
|
Binary file
|
mindspore/c1xx.dll
CHANGED
|
Binary file
|
mindspore/c2.dll
CHANGED
|
Binary file
|
mindspore/common/__init__.py
CHANGED
|
@@ -15,7 +15,8 @@
|
|
|
15
15
|
"""Top-level reference to dtype of common module."""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
from mindspore.common import dtype
|
|
18
|
-
from mindspore.common.api import
|
|
18
|
+
from mindspore.common.api import ms_memory_recycle, jit, jit_class, _no_grad, \
|
|
19
|
+
flops_collection, set_recursion_limit
|
|
19
20
|
from mindspore.common.dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
|
|
20
21
|
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
|
|
21
22
|
float32, single, float64, bfloat16, double, bool_, float_, list_, tuple_, int_, \
|
|
@@ -38,6 +39,7 @@ from mindspore.common import generator
|
|
|
38
39
|
from mindspore.common.generator import (
|
|
39
40
|
Generator, default_generator, seed, manual_seed, initial_seed, get_rng_state, set_rng_state)
|
|
40
41
|
from mindspore.ops.function.array_func import is_tensor, from_numpy
|
|
42
|
+
from mindspore.common._grad_function import _Function
|
|
41
43
|
|
|
42
44
|
# symbols from dtype
|
|
43
45
|
__all__ = [
|
|
@@ -69,18 +71,19 @@ __all__ = [
|
|
|
69
71
|
|
|
70
72
|
__all__.extend([
|
|
71
73
|
"tensor", "Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
|
|
72
|
-
|
|
74
|
+
'jit', 'jit_class', '_no_grad', # api
|
|
73
75
|
"Parameter", "ParameterTuple", # parameter
|
|
74
76
|
"dtype",
|
|
75
77
|
"set_seed", "get_seed", "manual_seed", # random seed
|
|
76
78
|
"set_dump",
|
|
77
79
|
"ms_memory_recycle",
|
|
80
|
+
"set_recursion_limit",
|
|
78
81
|
"mutable", "JitConfig",
|
|
79
82
|
"flops_collection",
|
|
80
83
|
"lazy_inline", "load_mindir", "save_mindir",
|
|
81
84
|
"no_inline",
|
|
82
85
|
"Symbol",
|
|
83
86
|
"recompute",
|
|
84
|
-
"is_tensor", "from_numpy",
|
|
87
|
+
"is_tensor", "from_numpy", "_Function"
|
|
85
88
|
])
|
|
86
89
|
__all__.extend(generator.__all__)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# Copyright 2025 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
"""Defines custom autograd function with functional form."""
|
|
17
|
+
|
|
18
|
+
from typing import Any
|
|
19
|
+
from mindspore._c_expression import FunctionBase as FunctionBase_
|
|
20
|
+
from mindspore.common.tensor import Tensor
|
|
21
|
+
|
|
22
|
+
__all__ = ['_Function']
|
|
23
|
+
|
|
24
|
+
class _Function(FunctionBase_):
|
|
25
|
+
"""
|
|
26
|
+
A Class provides the ability to custom autograd function.
|
|
27
|
+
|
|
28
|
+
Note:
|
|
29
|
+
It is only supported in pynative mode.
|
|
30
|
+
|
|
31
|
+
Supported Platforms:
|
|
32
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
33
|
+
"""
|
|
34
|
+
@staticmethod
|
|
35
|
+
def forward(ctx: Any, *args: Any, **kwars: Any) -> Any:
|
|
36
|
+
raise NotImplementedError("forward function should be customized.")
|
|
37
|
+
|
|
38
|
+
@staticmethod
|
|
39
|
+
def backward(ctx: Any, *grad_outputs: Any) -> Any:
|
|
40
|
+
raise NotImplementedError("backward function should be customized.")
|
|
41
|
+
|
|
42
|
+
@classmethod
|
|
43
|
+
def apply(cls, *args, **kwargs):
|
|
44
|
+
return super().apply(cls, *args, **kwargs)
|
|
45
|
+
|
|
46
|
+
def save_for_backward(self, *tensors: Tensor):
|
|
47
|
+
self.saved_tensors = tensors
|
|
48
|
+
|
|
49
|
+
def mark_dirty(self, *args: Tensor):
|
|
50
|
+
self.dirty_tensors = args
|
|
51
|
+
|
|
52
|
+
def mark_non_differentiable(self, *args: Tensor):
|
|
53
|
+
self.non_differentiable = args
|
|
54
|
+
|
|
55
|
+
def set_materialize_grads(self, value: bool):
|
|
56
|
+
self.materialize_grads = value
|