mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py
CHANGED
@@ -17,6 +17,7 @@
 """Providing interface methods."""
 from __future__ import absolute_import

+import gc
 import types
 import sys
 import os
@@ -24,11 +25,11 @@ import time
 import ast
 import inspect
 import importlib
-import hashlib
 import contextlib
+import json
 from collections import OrderedDict, namedtuple
 from functools import wraps
-import
+from typing import Optional, Callable
 import mindspore as ms
 from mindspore import context
 from mindspore import log as logger
@@ -39,21 +40,23 @@ from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
 from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
 from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
 from mindspore._c_expression.amp import get_curr_amp_strategy
-from mindspore._c_expression import GraphExecutor_,
+from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
-    _ms_memory_recycle, _bind_device_ctx
+    _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, StubNode, MSContext, TensorPy as Tensor
 from mindspore.parallel._ps_context import _is_role_sched
-from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast,
-
+from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
+    _is_parallel_mode
 from mindspore import _checkparam as Validator
 from mindspore._checkparam import is_stub_tensor
 from mindspore.common._utils import is_shape_unknown
-from mindspore.common.mutable import mutable
-from mindspore.common._register_for_adapter import ms_adapter_registry
+from mindspore.common.mutable import mutable, _check_element_type
 from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
     get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
 from mindspore.common._pijit_context import PIJitCaptureContext
-from mindspore.common.parameter import Parameter
+from mindspore.common.parameter import Parameter, set_parameter_hook_updated, parameter_hook_updated
+from mindspore.common.jit_context import jit_context
+from mindspore.common.jit_trace import _jit_trace
+from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context

 # Store ms_function class compiled pipeline cache.
 ms_compile_cache = set()
@@ -107,8 +110,7 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
             logger.info(f"The {echo_function_name} has been compiled again. "
                         f"{tips} ")
         else:
-            tips = "Try to
-                   "or @jit(compile_once=True) to reduce the compile time. " \
+            tips = "Try to reuse the function object decorated by @jit to reduce the compile time. " \
                    "For more details, get instructions about `jit` at " \
                    "https://www.mindspore.cn/search?inputValue=jit."
             logger.warning(f"The {echo_function_name} has been compiled again. "
@@ -120,14 +122,6 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
     function_phases[full_function_name].add(create_time)


-def _ms_adapter_tensor_as_parameter_output(data):
-    """Check whether the data is an output from a parameter which is a ms_adapter tensor.
-    Pylint: disable=unidiomatic-typecheck.
-    """
-    return ms_adapter_registry.is_registered and isinstance(data, ms_adapter_registry.tensor) \
-        and hasattr(data, "__ms_parameter_output__") and getattr(data, "__ms_parameter_output__")
-
-
 def _convert_python_data(data):
     """
     Convert C++ data to python.
@@ -138,18 +132,10 @@ def _convert_python_data(data):
     Returns:
         data, a data convert C++ to python
     """
-    if isinstance(data,
-        return
-    if
-        return data
-    if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
-        return PythonTensor(data, internal=True)
-    if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
-        return PythonCSRTensor(csr_tensor=data)
-    if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
-        return PythonCOOTensor(coo_tensor=data)
-    if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
-        return PythonRowTensor(row_tensor=data)
+    if isinstance(data, PythonTensor):
+        return data
+    if isinstance(data, StubNode):
+        return ms.common._stub_tensor._convert_stub(data)
     if data.__class__ is tuple:
         # Handle namedtuple since its type is tuple.
         if hasattr(data, "_fields"):
@@ -158,6 +144,12 @@ def _convert_python_data(data):
             fields = data_dict.keys()
             return namedtuple(type_name, fields)(**_convert_python_data(data_dict))
         return tuple(_convert_python_data(x) for x in data)
+    if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
+        return PythonCSRTensor(csr_tensor=data)
+    if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
+        return PythonCOOTensor(coo_tensor=data)
+    if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
+        return PythonRowTensor(row_tensor=data)
     if data.__class__ is list:
         # Keep list object not change for inplace operation.
         for i in range(len(data)):
@@ -273,7 +265,9 @@ def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
             else:
                 whole_module = module_name
             if n.name is not None:
-
+                if not whole_module.endswith("."):
+                    whole_module += "."
+                whole_module += n.name
             try:
                 module_spec = importlib.util.find_spec(whole_module, pkg)
             except (ModuleNotFoundError, ValueError):
@@ -305,7 +299,22 @@ def _get_compile_cache_dep_files():
     return compile_cache_dep_files


-def
+def _contains_auto_grad_tensor(obj):
+    """Check object is or contains auto grad tensor element"""
+    if isinstance(obj, PythonTensor):
+        return obj._has_auto_grad()
+    if isinstance(obj, (tuple, list)):
+        for element in obj:
+            if _contains_auto_grad_tensor(element):
+                return True
+    if isinstance(obj, dict):
+        for key in obj:
+            if _contains_auto_grad_tensor(obj[key]):
+                return True
+    return False
+
+
+def _add_mutable_attr(args_list, compile_args, is_grad):
     """Restore the mutable attr for every arg."""
     new_compile_args = ()
     for idx, arg in enumerate(args_list):
@@ -316,7 +325,12 @@ def _restore_mutable_attr(args_list, compile_args):
             else:
                 new_compile_args += (mutable(compile_args[idx], False),)
         else:
-
+            if is_grad and _contains_auto_grad_tensor(arg):
+                if not _check_element_type(arg):
+                    raise RuntimeError("Input \"%s\" contains tensor with gradient but can not mutable." % (str(arg)))
+                new_compile_args += (mutable(compile_args[idx], False),)
+            else:
+                new_compile_args += (compile_args[idx],)
     return new_compile_args


@@ -330,6 +344,7 @@ def _get_parameter_layout():

 def _handle_arg(obj, arg, compile_arg):
     """Handle arg for runtime .If need handle the arg, return True"""
+    from mindspore._extends.parse import compile_config
     if isinstance(arg, PythonTensor):
         if arg.has_init:
             arg.init_data()
@@ -342,7 +357,8 @@ def _handle_arg(obj, arg, compile_arg):
         if isinstance(arg, list) and not arg:
             return None
         return arg
-    elif context.get_context("grad_for_scalar")
+    elif (context.get_context("grad_for_scalar") or str(compile_config.GRAD_FOR_SCALAR) == '1') and \
+            isinstance(arg, (int, float)):
         return arg
     elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
             _check_all_tensor(arg):
@@ -528,12 +544,35 @@ def _get_parameter_ids(args, kwargs):
             parameter_ids += str(id(value))
     return parameter_ids

+def _get_tensor_hook_key(tensor):
+    """Get the hook key of Tensor/Parameter"""
+    return ".".join(map(str, map(id, tensor.hooks())))
+
+def _get_hook_key(*args, **kwargs):
+    """Get the hook key of Tensors/Parameters"""
+    hook_key = ""
+    for idx, arg in enumerate(args):
+        if idx != 0:
+            hook_key += "."
+        # Only arg of the type Tensor or Parameter is supported now
+        if isinstance(arg, (Tensor, Parameter)):
+            hook_key += _get_tensor_hook_key(arg)
+
+    for idx, value in enumerate(kwargs.values()):
+        if idx != 0:
+            hook_key += "."
+        # Only kwarg of the type Tensor or Parameter is supported now
+        if isinstance(value, (Tensor, Parameter)):
+            hook_key += _get_tensor_hook_key(value)
+
+    return hook_key
+

-class
+class _JitExecutor:
     """
     Represents a function compiled by graph compiler.

-
+    _JitExecutor will compile the original function for every combination
     of argument types and shapes it is given (as well as their values, optionally).

     Args:
@@ -547,7 +586,7 @@ class _MindsporeFunctionExecutor:
         The result of pipeline running in graph mode.
     """

-    def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None):
+    def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0):
         init_pipeline()
         if not isinstance(fn, (types.FunctionType, types.MethodType)):
             raise RuntimeError('fn {} is not function or method'.format(fn))
@@ -559,13 +598,19 @@ class _MindsporeFunctionExecutor:
         self.obj = obj
         self.shard_parent_obj = obj
         self.enable_tuple_broaden = False
-
+        if _run_jit_pipeline():
+            self._graph_executor = JitExecutor_.get_instance()
+        else:
+            self._graph_executor = GraphExecutor_.get_instance()
         self._create_time = ms_create_time
         self._compile_args = None
+        self._enable_auto_dynamic = dynamic == 1
         self.jit_config_dict = jit_config.jit_config_dict if jit_config else None

     @_wrap_func
     def __call__(self, *args, **kwargs):
+        if jit_context() and jit_context().is_nested():
+            return jit_context().run_graph("", None, *())
         args_list = args
         if self.obj is not None:
             args_list = args_list[1:]
@@ -581,13 +626,18 @@ class _MindsporeFunctionExecutor:
             _pynative_executor.clear_res()
             raise err

-        if context.get_context("precompile_only"):
+        if context.get_context("precompile_only") or os.getenv('MS_DEV_PRECOMPILE_ONLY') == '1':
             return None

         new_inputs = self._generate_run_args(args_list, kwargs)
-
-
-
+        if context.get_context("mode") == context.PYNATIVE_MODE and not jit_context():
+            output = _pynative_executor.grad_jit(*new_inputs)
+        else:
+            output = self._graph_executor(tuple(new_inputs), phase)
+            if jit_context():
+                if is_stub_tensor(output):
+                    output = output.stub_sync()
+                return jit_context().run_graph(phase, output, *tuple(new_inputs))

         return output

@@ -603,10 +653,13 @@ class _MindsporeFunctionExecutor:
         compile_args = self._generate_compile_args(args)
         key_id = self._get_key_id()
         compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
-                                                                              self.input_signature
+                                                                              self.input_signature,
+                                                                              self._enable_auto_dynamic)

-        #
-
+        # Add mutable for compile_args for two scene:
+        # 1) Origin args is mutable.
+        # 2) Args contains sequence with gradient tensor.
+        compile_args = _add_mutable_attr(args, compile_args, _pynative_executor.requires_grad())
         self._compile_args = compile_args
         generate_name, echo_function_name = self._get_generate_name()
         # The full Function name
@@ -645,11 +698,14 @@ class _MindsporeFunctionExecutor:
         parameter_ids = _get_parameter_ids(args, kwargs)
         if parameter_ids != "":
             key = str(key) + '.' + parameter_ids
+
+        key = str(key) + "." + _get_hook_key(*args, **kwargs)
+
         phase = generate_name + '.' + str(key)

         update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)

-        if phase in ms_compile_cache:
+        if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
             # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
             # generated in generate_arguments_key.
             self._graph_executor.clear_compile_arguments_resource()
@@ -671,7 +727,7 @@ class _MindsporeFunctionExecutor:
                 setattr(self.fn.__func__, "__jit_function__", True)
             else:
                 setattr(self.fn, "__jit_function__", True)
-            is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase
+            is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase)
             if isinstance(self.fn, types.MethodType):
                 delattr(self.fn.__func__, "__jit_function__")
             else:
@@ -679,10 +735,11 @@ class _MindsporeFunctionExecutor:
         else:
             if isinstance(self.obj, ms.nn.Cell):
                 self._graph_executor.set_weights_values(self.obj.parameters_dict())
-            is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase
+            is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase)

         if not is_compile:
             raise RuntimeError("Executor compile failed.")
+        set_parameter_hook_updated(False)
         ms_compile_cache.add(phase)

         return phase
@@ -704,7 +761,7 @@ class _MindsporeFunctionExecutor:
         else:
             key_id = str(id(self.obj)) + str(self._create_time)

-        if _pynative_executor.
+        if _pynative_executor.requires_grad():
             key_id = key_id + ".grad"
         return key_id

@@ -714,9 +771,9 @@ class _MindsporeFunctionExecutor:
                                            self.fn.__code__.co_firstlineno)
         echo_function_name = "function \"" + self.fn.__name__ + "\" at the file \"" + self.fn.__code__.co_filename \
             + "\", line " + str(self.fn.__code__.co_firstlineno)
-        if _pynative_executor.
+        if _pynative_executor.requires_grad():
             generate_name = generate_name + ".grad"
-        if
+        if self.fn.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
             generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
         return generate_name, echo_function_name

@@ -777,6 +834,14 @@ class _MindsporeFunctionExecutor:
         """
         return _get_args_for_run(self, args_list, kwargs, self._compile_args)

+    def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
+        """Get graph proto from pipeline."""
+        if use_prefix:
+            exec_id = exec_id + '.' + obj.arguments_key
+        if self._graph_executor.has_compiled(exec_id) is False:
+            return None
+        return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
+

 # The attributes used to identify a given object.
 attr_op = {"__str__": lambda x: x.__str__(),
@@ -789,6 +854,13 @@ attr_op = {"__str__": lambda x: x.__str__(),
           }


+def _is_inner_func(func):
+    """Check whether the func is an inner func which needs hash_args parameter."""
+    # This is a workaround for inner api, should fix it later.
+    inner_func = ["after_shard", "_wrap_container"]
+    return func.__name__ in inner_func
+
+
 def _get_obj_id(input_obj):
     """Get hash id of single object."""
     obj_id = ".".join(
@@ -803,50 +875,227 @@ def _get_jit_hash(hash_input):
     return _get_obj_id(hash_input)


-def
+def _get_hash_obj(options):
+    hash_obj = None
+    if "hash_args" in options:
+        hash_obj = _get_jit_hash(options["hash_args"])
+        del options["hash_args"]
+    return hash_obj
+
+
+def _check_option_device(option, device):
+    """Check jit options wiwh device"""
+    option_device_cfgs = {
+        'disable_format_transform': ['GPU'],
+        'exec_order': ['Ascend'],
+        'ge_options': ['Ascend'],
+        'infer_boost': ['Ascend'],
+    }
+    if option in option_device_cfgs and device not in option_device_cfgs[option]:
+        logger.warning(f"For 'jit(options)', the option '{option}' is only support device in "
+                       f"'{option_device_cfgs[option]}', but got '{device}', ignore it.")
+
+
+def _check_option_backend(option, backend):
+    """Check jit options wiwh backend"""
+    option_backend_cfgs = {
+        'disable_format_transform': ['ms_backend'],
+        'exec_order': ['ms_backend'],
+        'ge_options': ['GE'],
+        'infer_boost': ['ms_backend'],
+    }
+    if option in option_backend_cfgs and backend not in option_backend_cfgs[option]:
+        logger.warning(f"For 'jit(options)', the option '{option}' is only support backend in "
+                       f"'{option_backend_cfgs[option]}', but got '{backend}', ignore it.")
+
+
+def _check_disable_format_transform_value(option, disable_format_transform):
+    """check disable_format_transform option value"""
+    if not isinstance(disable_format_transform, bool):
+        raise TypeError(f"For 'jit(options)', the type of '{option}' must be bool, "
+                        f"but got {type(disable_format_transform)}.")
+
+
+def _check_exec_order_value(option, exec_order):
+    """check exec_order option value"""
+    if not isinstance(exec_order, str):
+        raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(exec_order)}.")
+
+    if exec_order not in ['bfs', 'dfs']:
+        raise ValueError(f"For '{option}', the value of '{option}' must be one of "
+                         f"['bfs', 'dfs'], but got '{exec_order}'.")
+
+
+def _check_ge_options_value(option, ge_options):
+    """check ge_options option value"""
+    if not isinstance(ge_options, dict):
+        raise TypeError(f"For 'jit(options)', the type of '{option}' must be dict, but got {type(ge_options)}.")
+
+    for level, options in ge_options.items():
+        if level not in ['global', 'session']:
+            raise ValueError(f"For '{option}', the key of '{option}' must be one of "
+                             f"['global', 'session'], but got '{level}'.")
+
+        if not isinstance(options, dict):
+            raise TypeError(f"For '{option}', the type of {level} options must be dict, "
+                            f"but got {type(options)}. The error options: {options}.")
+
+        for key, value in options.items():
+            if not isinstance(key, str):
+                raise TypeError(f"For '{option}', the type of key and value must be str, "
+                                f"but got {type(key)}. The error key is {key}.")
+            if not isinstance(value, str):
+                raise TypeError(f"For '{option}', the type of key and value must be str, "
+                                f"but got {type(value)}. The error value is {value}")
+
+
+def _check_infer_boost_value(option, value):
+    """check infer_boost option value"""
+    if not isinstance(value, str):
+        raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(value)}.")
+
+    if value not in ['on', 'off']:
+        raise ValueError(f"For '{option}', the value of '{option}' must be one of ['on', 'off'], but got '{value}'.")
+
+
+def _check_option_value(option, value):
+    """check jit options wiwh value"""
+    option_valuecheck_funcs = {
+        'disable_format_transform': _check_disable_format_transform_value,
+        'exec_order': _check_exec_order_value,
+        'ge_options': _check_ge_options_value,
+        'infer_boost': _check_infer_boost_value,
+    }
+    if option in option_valuecheck_funcs:
+        option_valuecheck_funcs[option](option, value)
+    else:
+        logger.warning(f"For 'jit(options)', the option argument '{option}' is not recognized, please check!"
+                       f"For detailed usage of 'jit(options)', please refer to the Mindspore official website.")
+
+
+def _check_options(options, backend):
+    """Check jit options"""
+    # check whether there are deprecated parameters in the dict `options`.
+    deprecated_args = {'mode': 'capture_mode', 'input_signature': 'dynamic', 'hash_args': '',
+                       'jit_config': 'jit_level, fullgraph or options', 'compile_once': ''}
+    for key, value in deprecated_args.items():
+        if key in options:
+            log = f"For 'jit', the parameter '{key}' has been deprecated."
+            if value != '':
+                log += f" Please use the parameter '{value}' instead. For more details, please refer to " \
+                       f"https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html."
+            logger.warning(log)
+            del options[key]
+
+    # check options' device, backend and value
+    for option, value in options.items():
+        _check_option_backend(option, backend)
+        _check_option_value(option, value)
+
+
+def jit(
+        function: Optional[Callable] = None,
+        *,
+        capture_mode: str = "ast",
+        jit_level: str = "O0",
+        dynamic: int = 0,
+        fullgraph: bool = False,
+        backend: str = "",
+        **options):
     """
     Create a callable MindSpore graph from a Python function.

     This allows the MindSpore runtime to apply optimizations based on graph.

     Note:
-
-
-
-
-
-          decorated with @jit(mode="PIJit") are not supported,
-          and the decoration @jit(mode="PIJit") is considered invalid.
+        - It is not supported to run a function with decoration @jit(capture_mode="bytecode")
+          in static graph mode, in which case the decoration @jit(capture_mode="bytecode") is considered invalid.
+        - Calls to functions with decorated @jit(capture_mode="bytecode") inside functions
+          decorated with @jit(capture_mode="ast") are not supported,
+          and the decoration @jit(capture_mode="bytecode") is considered invalid.

     Args:
-        [28 removed lines of the previous parameter documentation; their text is not captured in this diff view]
+        function (Function, optional): The Python function that will be run as a graph. Default: ``None``.
+
+    Keyword Args:
+        capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
+            should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
+
+            - `ast <https://www.mindspore.cn/tutorials/en/master/compile/static_graph.html>`_ :
+              Parse Python ast to build graph.
+            - `bytecode` :
+              Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
+              change and/or deletion.
+            - `trace` : Trace the execution of Python code to build graph. This is an experimental prototype that is
+              subject to change and/or deletion.
+
+        jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
+            with default backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
+
+            - `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
+            - `O1`: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
+              level is experimental and is being improved.
+
+        dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
+            is as follows:
+
+            - `0`: Do not perform dynamic shape compilation.
+            - `1`: Enable dynamic shape compilation and automatically detect shape changes.
+
+        fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
+            be compatible with all Python syntax in the function as much as possible. If True, we require that the
+            entire function can be captured into graph. If this is not possible (that is, if there is Python syntax
+            not supported), then it will raise an exception. This currently only applies when capture_mode is ast.
+            Default: ``False``.
+        backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
+            use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including Atlas
+            A2 training series products by default.
+
+            - `ms_backend`: Adopt KernelByKernel execution mode.
+            - `GE`: Adopt Sink execution mode. The whole model will be sinked to device to execute, only applicable to
+              the top cell of model. And only can be used in Ascend platform.
+
+        **options (dict): A dictionary of options to pass to the compilation backend.
+
+            Some options are device specific, see the below table for details:
+
+            +---------------------------+---------------------------+-------------------------+
+            | Option Parameters         | Hardware Platform Support | Backend Support         |
+            +===========================+===========================+=========================+
+            | disable_format_transform  | GPU                       | ms_backend              |
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1067
|
+
| exec_order | Ascend | ms_backend |
|
|
1068
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1069
|
+
| ge_options | Ascend | GE |
|
|
1070
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1071
|
+
| infer_boost | Ascend | ms_backend |
|
|
1072
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1073
|
+
|
|
1074
|
+
- disable_format_transform (bool, optional): Whether to disable the automatic format transform function
|
|
1075
|
+
from NCHW to NHWC. When the network training performance of fp16 is worse than fp32,
|
|
1076
|
+
`disable_format_transform` can be set to ``True`` to try to improve training performance.
|
|
1077
|
+
Default: ``False`` .
|
|
1078
|
+
- exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
|
|
1079
|
+
methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
|
|
1080
|
+
|
|
1081
|
+
- `bfs`: The default sorting method, breadth priority, good communication masking, relatively good
|
|
1082
|
+
performance.
|
|
1083
|
+
- `dfs`: An optional sorting method, depth-first sorting. The performance is relatively worse than that
|
|
1084
|
+
of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
|
|
1085
|
+
other execution orders run out of memory (OOM).
|
|
1086
|
+
|
|
1087
|
+
- ge_options (dict): Set options for ge backend. The options are divided into two categories: global,
|
|
1088
|
+
and session. This is an experimental prototype that is subject to change and/or deletion.
|
|
1089
|
+
For detailed information, please refer to `Ascend community <https://www.hiascend.com/document/detail/zh/canncommercial/80RC3/apiref/ascendgraphapi/atlasgeapi_07_0146.html>`_ .
|
|
1090
|
+
|
|
1091
|
+
- global (dict): Set global options.
|
|
1092
|
+
- session (dict): Set session options.
|
|
1093
|
+
|
|
1094
|
+
- infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
|
|
1095
|
+
the inference mode is disabled. The range is as follows:
|
|
1096
|
+
|
|
1097
|
+
- `on`: Enable inference mode, get better infer performance.
|
|
1098
|
+
- `off`: Disable inference mode, use forward for inference. The performance is poor.
|
|
850
1099
|
|
|
851
1100
|
Returns:
|
|
852
1101
|
Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
|
|
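The hunk above replaces the old `fn`/`mode`/`input_signature`/`hash_args`/`jit_config` parameters with a keyword-only signature whose extra keywords are validated by `_check_options`. A usage sketch based only on that signature and the option table; the function name, shapes and option values are illustrative, not taken from the package source:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor

# Keyword-only configuration of the 2.6.0rc1 jit signature shown above.
@ms.jit(capture_mode="ast", jit_level="O1", fullgraph=False, backend="ms_backend")
def scaled_add(x, y):
    return 2 * x + y

x = Tensor(np.ones((2, 2), np.float32))
y = Tensor(np.ones((2, 2), np.float32))
out = scaled_add(x, y)   # compiled on the first call, cached afterwards

# Deprecated keywords such as mode= or hash_args= are no longer part of the API;
# per _check_options above they only produce a warning and are dropped.
```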
@@ -865,12 +1114,12 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
         >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
         >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
         ...
-        >>> # create a callable MindSpore graph by calling
+        >>> # create a callable MindSpore graph by calling jit
         >>> def tensor_add(x, y):
         ...     z = x + y
         ...     return z
         ...
-        >>> tensor_add_graph = jit(
+        >>> tensor_add_graph = jit(function=tensor_add)
         >>> out = tensor_add_graph(x, y)
         ...
         >>> # create a callable MindSpore graph through decorator @jit
@@ -881,180 +1130,70 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
         ...
         >>> out = tensor_add_with_dec(x, y)
         ...
-        >>> # create a callable MindSpore graph
-        >>> @jit(
-        ...
-        ... def tensor_add_with_sig(x, y):
-        ...     z = x + y
-        ...     return z
-        ...
-        >>> out = tensor_add_with_sig(x, y)
-        ...
-        >>> @jit(input_signature={"y": Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))})
-        ... def tensor_add_with_sig_1(x, y):
+        >>> # create a callable MindSpore graph and capture the entire function into the graph
+        >>> @jit(fullgraph=True)
+        ... def tensor_add_fullgraph(x, y):
         ...     z = x + y
         ...     return z
         ...
-        >>>
-        ...
-        ... # Set hash_args as fn, otherwise cache of compiled closure_fn will not be reused.
-        ... # While fn differs during calling again, recompilation will be triggered.
-        >>> def func(x):
-        ...     return ops.exp(x)
-        ...
-        >>> def closure_fn(x, fn):
-        ...     @jit(hash_args=fn)
-        ...     def inner_fn(a):
-        ...         return fn(a)
-        ...     return inner_fn(x)
-        ...
-        >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
-        >>> for i in range(10):
-        ...     closure_fn(inputs, func)
-        ...
-        ... # Set compile_once = True, otherwise the train_step will be compiled again.
-        >>> def train(x):
-        ...     @jit(compile_once = True)
-        ...     def train_step(x):
-        ...         return ops.exp(x)
-        ...     for i in range(10):
-        ...         train_step(x)
-        ...
-        >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
-        >>> for i in range(10):
-        ...     train(inputs)
+        >>> out = tensor_add_fullgraph(x, y)
     """

-    [9 removed lines elided in this diff view]
+    capture_mode = Validator.check_string(capture_mode, ["ast", "bytecode", "trace"], "capture_mode", "jit")
+    jit_level = Validator.check_string(jit_level, ["O0", "O1"], "jit_level", "jit")
+    dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
+    fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
+    if backend == "":
+        backend = "GE" if MSContext.get_instance().get_ascend_soc_version() == "ascend910" else "ms_backend"
+    backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
+    jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
+    hash_obj = _get_hash_obj(options)
+    _check_options(options, backend)
+    options_str = json.dumps(options)
+    infer_boost = options['infer_boost'] if 'infer_boost' in options else "off"
+    exc_mode = options['exc_mode'] if 'exc_mode' in options else "auto"
+    jit_config = JitConfig(jit_level=jit_level, exc_mode=exc_mode, jit_syntax_level=jit_syntax_level,
+                           infer_boost=infer_boost, backend=backend, options=options_str)
+
+    def wrap_func(func):
+        nonlocal hash_obj
+        if hash_obj is None or not _is_inner_func(func):
             hash_obj = int(time.time() * 1e9)

-        dyn_args = _process_dyn_args(func, input_signature)
-
         @wraps(func)
         def staging_specialize(*args, **kwargs):
             if os.getenv("MS_JIT") == '0':
                 return func(*args, **kwargs)

             args, kwargs = _handle_func_args(func, *args, **kwargs)
-
             process_obj = None
             if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
                 process_obj = args[0]
-            # only the function or cell instance wrapped by shard will fall into this branch
-            if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
-                process_obj = hash_args
             # Handle auto mixed precision strategy.
             if not hasattr(func, "amp_strategy"):
                 if isinstance(func, types.MethodType):
                     setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
                 else:
                     setattr(func, "amp_strategy", get_curr_amp_strategy())
-
+
+            ms_function_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic)
+            out = ms_function_executor(*args, **kwargs)
             return out

         return staging_specialize

-
-
-
+    if capture_mode == "bytecode":
+        wrap_func = PIJitCaptureContext(jit_config)
+    elif capture_mode == "trace":
+        if function is not None:
+            return _jit_trace(function)
+        return _jit_trace

-    if
-    return wrap_func(
+    if function is not None:
+        return wrap_func(function)
     return wrap_func


-def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
-    """
-    Create a callable MindSpore graph from a Python function.
-
-    This allows the MindSpore runtime to apply optimizations based on graph.
-
-    Note:
-        - `ms_function` will be deprecated and removed in a future version. Please use :func:`mindspore.jit` instead.
-        - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
-          will not accept `**kwargs`.
-
-    Args:
-        fn (Function): The Python function that will be run as a graph. Default: ``None`` .
-        input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
-            will be supplied to this function. The shape and dtype of actual inputs of `fn` should
-            keep the same as input_signature. Otherwise, TypeError will be raised. Default: ``None`` .
-        hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
-            like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
-            will trigger recompilation. Default: ``None`` .
-        jit_config (JitConfig): Jit config for compile. Default: ``None`` .
-
-    Returns:
-        Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
-        None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
-        equal to the case when `fn` is not None.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> import numpy as np
-        >>> from mindspore import Tensor
-        >>> from mindspore import ops
-        >>> from mindspore import ms_function
-        ...
-        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
-        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
-        ...
-        >>> # create a callable MindSpore graph by calling ms_function
-        >>> def tensor_add(x, y):
-        ...     z = x + y
-        ...     return z
-        ...
-        >>> tensor_add_graph = ms_function(fn=tensor_add)
-        >>> out = tensor_add_graph(x, y)
-        ...
-        >>> # create a callable MindSpore graph through decorator @ms_function
-        >>> @ms_function
-        ... def tensor_add_with_dec(x, y):
-        ...     z = x + y
-        ...     return z
-        ...
-        >>> out = tensor_add_with_dec(x, y)
-        ...
-        >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
-        >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
-        ...                               Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
-        ... def tensor_add_with_sig(x, y):
-        ...     z = x + y
-        ...     return z
-        ...
-        >>> out = tensor_add_with_sig(x, y)
-        ...
-        ... # Set hash_args as fn, otherwise cache of compiled `closure_fn` will not be reused.
-        ... # While fn differs during calling again, recompilation will be triggered.
-        >>> def func(x):
-        ...     return ops.exp(x)
-        ...
-        >>> def closure_fn(x, fn):
-        ...     @ms_function(hash_args=fn)
-        ...     def inner_fn(a):
-        ...         return fn(a)
-        ...     return inner_fn(x)
-        ...
-        >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
-        >>> for i in range(10):
-        ...     closure_fn(inputs, func)
-    """
-
-    logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
-                   "Please use 'mindspore.jit' instead.")
-    return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)
-
-
 def _core(fn=None, **flags):
     """
     A decorator that adds a flag to the function.
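`ms_function` is removed in the hunk above; in 2.4 it only forwarded to `jit` with a deprecation warning. A migration sketch under the parameter mapping listed in `_check_options` (`input_signature` is superseded by `dynamic`, `hash_args`/`jit_config` are simply dropped); tensors and names here are illustrative:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor

# MindSpore <= 2.4:  tensor_add_graph = ms_function(fn=tensor_add)
# MindSpore 2.6:     call jit directly.
def tensor_add(x, y):
    return x + y

tensor_add_graph = ms.jit(tensor_add)       # `function` is the only positional parameter
dyn_graph = ms.jit(tensor_add, dynamic=1)   # replaces input_signature-style dynamic-shape hints

x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
out = tensor_add_graph(x, y)
```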
@@ -1147,69 +1286,6 @@ def _no_recursive(callable_obj):
     return callable_obj


-def ms_class(cls):
-    """
-    Class decorator for user-defined classes.
-
-    This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
-
-    Note:
-        `ms_class` will be deprecated and removed in a future version. Please use :func:`mindspore.jit_class` instead.
-
-    Args:
-        cls (Class): User-defined class.
-
-    Returns:
-        Class.
-
-    Raises:
-        TypeError: If ms_class is used for non-class types or nn.Cell.
-        AttributeError: If the private attributes or magic methods of the class decorated with ms_class is called.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore.nn as nn
-        >>> from mindspore import ms_class
-        ...
-        >>> @ms_class
-        ... class UserDefinedNet:
-        ...     def __init__(self):
-        ...         self.value = 10
-        ...
-        ...     def func(self, x):
-        ...         return 2 * x
-        ...
-        >>> class Net(nn.Cell):
-        ...     def __init__(self):
-        ...         super(Net, self).__init__()
-        ...         self.net = UserDefinedNet()
-        ...
-        ...     def construct(self, x):
-        ...         out = self.net.value + self.net.func(x)
-        ...         return out
-        ...
-        >>> net = Net()
-        >>> out = net(5)
-        >>> print(out)
-        20
-    """
-
-    logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
-                   "Please use 'mindspore.jit_class' instead.")
-
-    # Check if cls is of type class.
-    if not inspect.isclass(cls):
-        raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
-    # Check if cls is nn.Cell.
-    if issubclass(cls, ms.nn.Cell):
-        raise TypeError(f"Decorator ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
-    logger.info(f'Found ms_class: {cls}.')
-    setattr(cls, '__ms_class__', True)
-    return cls
-
-
 def jit_class(cls):
     """
     Class decorator for user-defined classes.
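`ms_class` is removed above while `jit_class` keeps the same contract, so downstream code only needs the decorator renamed. A sketch adapted from the removed docstring example:

```python
import mindspore as ms
import mindspore.nn as nn

@ms.jit_class                      # was @ms_class in MindSpore <= 2.4
class UserDefinedNet:
    def __init__(self):
        self.value = 10

    def func(self, x):
        return 2 * x

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.net = UserDefinedNet()

    def construct(self, x):
        return self.net.value + self.net.func(x)

out = Net()(5)                     # 20, matching the removed ms_class example
```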
@@ -1266,28 +1342,6 @@ def jit_class(cls):
     return cls


-def set_adapter_config(config):
-    """
-    Register configuration information for MSAdapter.
-
-    Args:
-        config (dict): Configuration information.
-    """
-    if not isinstance(config, dict):
-        raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
-    for key, value in config.items():
-        if key == "Tensor":
-            ms_adapter_registry.register_tensor(value)
-        elif key == "Parameter":
-            ms_adapter_registry.register_parameter(value)
-        elif key == "convert_object_map":
-            ms_adapter_registry.register_convert_map(value)
-        elif key == "convert_adapter_tensor_map":
-            ms_adapter_registry.register_convert_adapter_tensor_map(value)
-        else:
-            raise ValueError(f"Unsupported key in adapter config: {key}")
-
-
 def _function_forbid_reuse(func):
     if not inspect.isfunction(func):
         raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
@@ -1351,8 +1405,6 @@ class _no_grad(contextlib.ContextDecorator):
         self.prev_state = False

     def __enter__(self):
-        if context.get_context("mode") == context.GRAPH_MODE:
-            raise RuntimeError("For no_grad feature, currently only support Pynative mode, but got Graph mode.")
         self.prev_state = _pynative_executor.enable_grad()
         _pynative_executor.set_enable_grad(False)

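The `_no_grad` hunk above drops the check that rejected graph mode on entry. A minimal sketch of the context manager it modifies; `_no_grad` is an inner helper defined in this module, so importing it directly is shown only for illustration and the public entry point may differ:

```python
from mindspore import Tensor, ops
from mindspore.common.api import _no_grad   # inner helper from the file being diffed

x = Tensor([1.0, 2.0, 3.0])
with _no_grad():                             # gradients are not recorded inside the block
    y = ops.mul(x, 2.0)
# With the two removed lines above, entering this context under GRAPH_MODE no longer raises.
```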
@@ -1481,7 +1533,24 @@ class _PyNativeExecutor:
         Return:
             None.
         """
-        return self._executor.grad(grad, obj, weights, grad_position, *args)
+        return self._executor.grad(grad, obj, weights, grad_position, False, *args)
+
+    def grad_aux(self, obj, grad, weights, grad_position, *args):
+        """
+        Run grad graph with aux
+
+        Args:
+            obj (Function/Cell): The function or cell instance.
+            grad (GradOperation): The gradoperation object.
+            weights (ParameterTuple): The weights of cell instance.
+            grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
+                If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
+            args (tuple): Function or cell input arguments.
+
+        Return:
+            None.
+        """
+        return self._executor.grad(grad, obj, weights, grad_position, True, *args)

     def clear_res(self):
         """
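The extra boolean passed to `self._executor.grad(...)` is what separates the plain path from the new `grad_aux` path. At the public API level this corresponds to gradient transforms whose forward function returns auxiliary outputs; a hedged sketch using `mindspore.value_and_grad` with `has_aux=True` (the mapping onto `grad_aux` is an assumption based on this hunk):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

def forward(x):
    loss = ops.sin(x).sum()      # differentiated output
    aux = x * 2                  # auxiliary output, carried through untouched
    return loss, aux

x = Tensor(np.ones((2, 2), np.float32))
grad_fn = ms.value_and_grad(forward, grad_position=0, has_aux=True)
(loss, aux), grad = grad_fn(x)   # outputs keep the aux value; grad is w.r.t. x only
```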
@@ -1501,18 +1570,18 @@ class _PyNativeExecutor:
         """
         self._executor.sync()

-    def grad_jit(self,
+    def grad_jit(self, *args):
         """
         Building grad graph decorated by jit.

         Args:
-            output (tuple): The function or cell decorated by jit output object.
             args (tuple): Function or cell decorated by jit input arguments.

         Return:
-
+            output: The output object of function or cell decorated by jit.
         """
-
+        output = self._executor.grad_jit(*args)
+        return output

     def call_custom_bprop(self, obj, output, *args, **kwargs):
         """
@@ -1617,6 +1686,15 @@ class _PyNativeExecutor:
         """
         self._executor.set_is_run_recompute(status)

+    def high_order(self):
+        """
+        Is high order of current scene, this is a inner interface.
+
+        Return:
+            Bool.
+        """
+        return self._executor.high_order()
+
     def set_cell_use_dynamic_shape_process(self, flag):
         """
         Set the dynamic shape flag of eval process.
@@ -1699,7 +1777,6 @@ class _CellGraphExecutor:
         # create needed graph by lazy mode
         self.is_init = False
         self.enable_tuple_broaden = False
-        self.obfuscate_config = None # used for model's dynamic obfuscation
         self._graph_executor = GraphExecutor_.get_instance()
         self._graph_executor.set_py_exe_path(sys.executable)
         self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
@@ -1791,6 +1868,7 @@ class _CellGraphExecutor:
             Str, the full phase of the cell.
             Bool, if the graph has been compiled before, return False, else return True.
         """
+        _init_auto_parallel_context(obj)
         obj.__parse_method__ = 'construct'
         if not hasattr(obj, obj.__parse_method__):
             raise AttributeError(
@@ -1803,8 +1881,12 @@ class _CellGraphExecutor:
         self.enable_tuple_broaden = obj.enable_tuple_broaden
         logger.debug(f"Convert the network: {do_convert}.")
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
+
         key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
         obj.arguments_key = str(key)
+
+        obj.arguments_key = obj.arguments_key + "." + _get_hook_key(*args, **kwargs)
+
         # When exist parameter in the top graph inputs, need check if the parameter object has changed.
         parameter_ids = _get_parameter_ids(args, kwargs)
         if parameter_ids != "":
@@ -1814,11 +1896,12 @@ class _CellGraphExecutor:
             obj.phase_cache[raw_phase] = phase
         update_auto_dynamic_shape_phase(args, key_id, phase)
         obj.current_phase = phase
-        if phase in obj.compile_cache and self.has_compiled(phase):
+        if phase in obj.compile_cache and self.has_compiled(phase) and not parameter_hook_updated():
             logger.debug("%r graph has existed.", phase)
             # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
             # generated in generate_arguments_key.
             self._graph_executor.clear_compile_arguments_resource()
+            _clear_auto_parallel_context(obj)
             return phase, False

         full_function_name = obj.__class__.__name__ + '.' + str(obj.instance_count) + '.' + str(id(type(obj)))
@@ -1836,10 +1919,12 @@ class _CellGraphExecutor:
         else:
             jit_config_dict = JitConfig().jit_config_dict
             self._graph_executor.set_jit_config(jit_config_dict)
-
+        gc.collect()
+        result = self._graph_executor.compile(obj, args, kwargs, phase)
         obj.compile_cache.add(phase)
         if not result:
             raise RuntimeError("Executor compile failed.")
+        set_parameter_hook_updated(False)
         graph = self._graph_executor.get_func_graph(phase)

         if graph is None:
@@ -1856,6 +1941,7 @@ class _CellGraphExecutor:
             self._build_data_graph(obj, phase)
         elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
             _parameter_broadcast(obj)
+        _clear_auto_parallel_context(obj)
         return phase, True

     def _update_param_node_default_input(self, phase, replace):
@@ -1875,7 +1961,7 @@ class _CellGraphExecutor:
         return self._graph_executor.get_allreduce_fusion(real_phase)

     def __call__(self, obj, *args, phase='predict'):
-        if context.get_context("precompile_only") or _is_role_sched():
+        if context.get_context("precompile_only") or os.getenv('MS_DEV_PRECOMPILE_ONLY') == '1' or _is_role_sched():
             return None
         return self.run(obj, *args, phase=phase)

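The `__call__` change above adds an environment-variable switch next to the existing `precompile_only` context flag. A short sketch of how it could be used when only graph compilation is wanted; the effect described in the comment is inferred from the hunk, not from release notes:

```python
import os

# Ask the executor to stop after compilation: __call__ above returns None
# instead of running the compiled graph when this variable is set to '1'.
os.environ["MS_DEV_PRECOMPILE_ONLY"] = "1"
```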
@@ -1935,25 +2021,12 @@ class _CellGraphExecutor:
         """Clear the memory resource of a network."""
         self._graph_executor.del_net_res(obj, net_id)

-    def _get_branch_control_input(self):
-        if ('obf_ratio' not in self.obfuscate_config.keys()) or (
-                'obf_random_seed' not in self.obfuscate_config.keys()):
-            raise ValueError("'obf_ratio' and 'obf_random_seed' must be in obfuscate_config.")
-        obf_random_seed = self.obfuscate_config.get('obf_random_seed')
-        if obf_random_seed == 0:
-            branch_control_input = 0
-        else:
-            branch_control_input = _generate_branch_control_input(obf_random_seed)
-        return branch_control_input
-
     def _get_func_graph(self, obj, exec_id, use_prefix=False):
         """Get func graph from pipeline."""
         if use_prefix:
             exec_id = exec_id + '.' + obj.arguments_key
         if self._graph_executor.has_compiled(exec_id) is False:
             return None
-        if self.obfuscate_config is not None:
-            raise ValueError('For get func graph, obfuscate_config is currently not supported now.')
         return self._graph_executor.get_func_graph(exec_id)

     def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
@@ -1962,11 +2035,6 @@ class _CellGraphExecutor:
             exec_id = exec_id + '.' + obj.arguments_key
         if self._graph_executor.has_compiled(exec_id) is False:
             return None
-        if self.obfuscate_config is not None:
-            branch_control_input = self._get_branch_control_input()
-            return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, incremental,
-                                                                       self.obfuscate_config['obf_ratio'],
-                                                                       branch_control_input)
         return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)

     def get_optimize_graph_proto(self, obj):
@@ -2004,6 +2072,8 @@ def ms_memory_recycle():
     """
     if ms_compile_cache:
         _cell_graph_executor.del_net_res(None, ms_compile_cache)
+        if os.getenv('MS_DEV_JIT_PIPELINE') != '0':
+            JitExecutor_.get_instance().del_net_res(None, ms_compile_cache)
         ms_compile_cache.clear()
     for cell_cache in cells_compile_cache.values():
         if cell_cache:
@@ -2012,28 +2082,22 @@ def ms_memory_recycle():
     _ms_memory_recycle()


-def
-    """
-    [14 removed lines elided in this diff view]
-    hex_base = 16
-    for item in sha_result:
-        if int(item, hex_base) > 0:
-            branch_control_input *= int(item, hex_base)
-        branch_control_input %= int_max
-    return branch_control_input
+def set_recursion_limit(recursion_limit=1000):
+    """
+    Specify the recursion depth limit of function call before compiling graph.
+    It needs to be call when the nested function call is too deep or the number of sub graphs is too large.
+    If recursion_limit is set larger than before, the system max stack depth should be set larger too,
+    otherwise a `core dumped` exception may be raised because of system stack overflow.
+
+    Args:
+        recursion_limit (int, optional): The recursion depth limit. Must be a positive integer. Default: ``1000`` .
+
+    Examples:
+        >>> import mindspore as ms
+        >>> ms.set_recursion_limit(10000)
+    """
+    recursion_limit = Validator.check_positive_int(recursion_limit)
+    GraphExecutor_.get_instance().set_max_call_depth(recursion_limit)


 def _bind_device_context():
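The new `set_recursion_limit` docstring warns that the system stack must grow along with the compile-time recursion limit. A companion sketch; the `sys` and `threading` values are illustrative and are not taken from the MindSpore source:

```python
import sys
import threading
import mindspore as ms

ms.set_recursion_limit(5000)             # deeper nesting allowed during graph compilation
sys.setrecursionlimit(10000)             # keep CPython's own limit comfortably above it
threading.stack_size(64 * 1024 * 1024)   # larger stacks for threads created afterwards
```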
@@ -2058,4 +2122,4 @@ def flops_collection(phase='train'):
 _cell_graph_executor = _CellGraphExecutor()
 _pynative_executor = _PyNativeExecutor()

-__all__ = ['
+__all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']