mindspore 2.5.0__cp311-cp311-win_amd64.whl → 2.6.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +25 -194
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +109 -75
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +2014 -3386
- mindspore/common/api.py +386 -355
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/generator.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +332 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +228 -571
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +109 -77
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +115 -147
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +133 -702
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +198 -113
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +234 -28
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1253 -179
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +18 -14
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +32 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
- mindspore/ops/auto_generate/gen_extend_func.py +286 -208
- mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
- mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1631 -2347
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3024 -3855
- mindspore/ops/function/nn_func.py +678 -274
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +216 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +8 -5
- mindspore/ops/functional_overload.py +655 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +21 -14
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +39 -24
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +287 -32
- mindspore/ops/operations/debug_ops.py +119 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +67 -224
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +243 -17
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +140 -12
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +658 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -62
- mindspore/parallel/transform_safetensors.py +288 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +37 -13
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +43 -9
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +262 -127
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +2 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/common/api.py
CHANGED
|
@@ -17,6 +17,7 @@
|
|
|
17
17
|
"""Providing interface methods."""
|
|
18
18
|
from __future__ import absolute_import
|
|
19
19
|
|
|
20
|
+
import gc
|
|
20
21
|
import types
|
|
21
22
|
import sys
|
|
22
23
|
import os
|
|
@@ -24,11 +25,11 @@ import time
|
|
|
24
25
|
import ast
|
|
25
26
|
import inspect
|
|
26
27
|
import importlib
|
|
27
|
-
import hashlib
|
|
28
28
|
import contextlib
|
|
29
|
+
import json
|
|
29
30
|
from collections import OrderedDict, namedtuple
|
|
30
31
|
from functools import wraps
|
|
31
|
-
import
|
|
32
|
+
from typing import Optional, Callable
|
|
32
33
|
import mindspore as ms
|
|
33
34
|
from mindspore import context
|
|
34
35
|
from mindspore import log as logger
|
|
@@ -39,21 +40,23 @@ from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
|
|
|
39
40
|
from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
|
|
40
41
|
from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
|
|
41
42
|
from mindspore._c_expression.amp import get_curr_amp_strategy
|
|
42
|
-
from mindspore._c_expression import GraphExecutor_,
|
|
43
|
+
from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
|
|
43
44
|
PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
|
|
44
|
-
_ms_memory_recycle, _bind_device_ctx, StubNode
|
|
45
|
+
_run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, StubNode, MSContext, TensorPy as Tensor
|
|
45
46
|
from mindspore.parallel._ps_context import _is_role_sched
|
|
46
|
-
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast,
|
|
47
|
-
|
|
47
|
+
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
|
|
48
|
+
_is_parallel_mode
|
|
48
49
|
from mindspore import _checkparam as Validator
|
|
49
50
|
from mindspore._checkparam import is_stub_tensor
|
|
50
51
|
from mindspore.common._utils import is_shape_unknown
|
|
51
52
|
from mindspore.common.mutable import mutable, _check_element_type
|
|
52
|
-
from mindspore.common._register_for_adapter import ms_adapter_registry
|
|
53
53
|
from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
|
|
54
54
|
get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
|
|
55
55
|
from mindspore.common._pijit_context import PIJitCaptureContext
|
|
56
56
|
from mindspore.common.parameter import Parameter, set_parameter_hook_updated, parameter_hook_updated
|
|
57
|
+
from mindspore.common.jit_context import jit_context
|
|
58
|
+
from mindspore.common.jit_trace import _jit_trace
|
|
59
|
+
from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
|
|
57
60
|
|
|
58
61
|
# Store ms_function class compiled pipeline cache.
|
|
59
62
|
ms_compile_cache = set()
|
|
@@ -107,8 +110,7 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
|
|
|
107
110
|
logger.info(f"The {echo_function_name} has been compiled again. "
|
|
108
111
|
f"{tips} ")
|
|
109
112
|
else:
|
|
110
|
-
tips = "Try to
|
|
111
|
-
"or @jit(compile_once=True) to reduce the compile time. " \
|
|
113
|
+
tips = "Try to reuse the function object decorated by @jit to reduce the compile time. " \
|
|
112
114
|
"For more details, get instructions about `jit` at " \
|
|
113
115
|
"https://www.mindspore.cn/search?inputValue=jit."
|
|
114
116
|
logger.warning(f"The {echo_function_name} has been compiled again. "
|
|
@@ -120,14 +122,6 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
|
|
|
120
122
|
function_phases[full_function_name].add(create_time)
|
|
121
123
|
|
|
122
124
|
|
|
123
|
-
def _ms_adapter_tensor_as_parameter_output(data):
|
|
124
|
-
"""Check whether the data is an output from a parameter which is a ms_adapter tensor.
|
|
125
|
-
Pylint: disable=unidiomatic-typecheck.
|
|
126
|
-
"""
|
|
127
|
-
return ms_adapter_registry.is_registered and isinstance(data, ms_adapter_registry.tensor) \
|
|
128
|
-
and hasattr(data, "__ms_parameter_output__") and getattr(data, "__ms_parameter_output__")
|
|
129
|
-
|
|
130
|
-
|
|
131
125
|
def _convert_python_data(data):
|
|
132
126
|
"""
|
|
133
127
|
Convert C++ data to python.
|
|
@@ -138,18 +132,8 @@ def _convert_python_data(data):
|
|
|
138
132
|
Returns:
|
|
139
133
|
data, a data convert C++ to python
|
|
140
134
|
"""
|
|
141
|
-
if isinstance(data,
|
|
142
|
-
return
|
|
143
|
-
if _ms_adapter_tensor_as_parameter_output(data) and hasattr(data, "tensor"):
|
|
144
|
-
return data.tensor
|
|
145
|
-
if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
|
|
146
|
-
return PythonTensor(data, internal=True)
|
|
147
|
-
if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
|
|
148
|
-
return PythonCSRTensor(csr_tensor=data)
|
|
149
|
-
if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
|
|
150
|
-
return PythonCOOTensor(coo_tensor=data)
|
|
151
|
-
if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
|
|
152
|
-
return PythonRowTensor(row_tensor=data)
|
|
135
|
+
if isinstance(data, PythonTensor):
|
|
136
|
+
return data
|
|
153
137
|
if isinstance(data, StubNode):
|
|
154
138
|
return ms.common._stub_tensor._convert_stub(data)
|
|
155
139
|
if data.__class__ is tuple:
|
|
@@ -160,6 +144,12 @@ def _convert_python_data(data):
|
|
|
160
144
|
fields = data_dict.keys()
|
|
161
145
|
return namedtuple(type_name, fields)(**_convert_python_data(data_dict))
|
|
162
146
|
return tuple(_convert_python_data(x) for x in data)
|
|
147
|
+
if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
|
|
148
|
+
return PythonCSRTensor(csr_tensor=data)
|
|
149
|
+
if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
|
|
150
|
+
return PythonCOOTensor(coo_tensor=data)
|
|
151
|
+
if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
|
|
152
|
+
return PythonRowTensor(row_tensor=data)
|
|
163
153
|
if data.__class__ is list:
|
|
164
154
|
# Keep list object not change for inplace operation.
|
|
165
155
|
for i in range(len(data)):
|
|
@@ -578,11 +568,11 @@ def _get_hook_key(*args, **kwargs):
|
|
|
578
568
|
return hook_key
|
|
579
569
|
|
|
580
570
|
|
|
581
|
-
class
|
|
571
|
+
class _JitExecutor:
|
|
582
572
|
"""
|
|
583
573
|
Represents a function compiled by graph compiler.
|
|
584
574
|
|
|
585
|
-
|
|
575
|
+
_JitExecutor will compile the original function for every combination
|
|
586
576
|
of argument types and shapes it is given (as well as their values, optionally).
|
|
587
577
|
|
|
588
578
|
Args:
|
|
@@ -596,7 +586,7 @@ class _MindsporeFunctionExecutor:
|
|
|
596
586
|
The result of pipeline running in graph mode.
|
|
597
587
|
"""
|
|
598
588
|
|
|
599
|
-
def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None):
|
|
589
|
+
def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0):
|
|
600
590
|
init_pipeline()
|
|
601
591
|
if not isinstance(fn, (types.FunctionType, types.MethodType)):
|
|
602
592
|
raise RuntimeError('fn {} is not function or method'.format(fn))
|
|
@@ -608,13 +598,61 @@ class _MindsporeFunctionExecutor:
|
|
|
608
598
|
self.obj = obj
|
|
609
599
|
self.shard_parent_obj = obj
|
|
610
600
|
self.enable_tuple_broaden = False
|
|
611
|
-
|
|
601
|
+
if _run_jit_pipeline():
|
|
602
|
+
self._graph_executor = JitExecutor_.get_instance()
|
|
603
|
+
else:
|
|
604
|
+
self._graph_executor = GraphExecutor_.get_instance()
|
|
612
605
|
self._create_time = ms_create_time
|
|
613
606
|
self._compile_args = None
|
|
607
|
+
self._enable_auto_dynamic = dynamic == 1
|
|
614
608
|
self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
|
|
615
609
|
|
|
610
|
+
def _predict(self, *args, **kwargs):
|
|
611
|
+
"""Dedicated routine for predict."""
|
|
612
|
+
if not hasattr(self.obj, "phase"):
|
|
613
|
+
return False, None
|
|
614
|
+
|
|
615
|
+
predict_vailid_phase = {"prefill", 'increment'}
|
|
616
|
+
predict_phase = self.obj.phase
|
|
617
|
+
if predict_phase not in predict_vailid_phase:
|
|
618
|
+
return False, None
|
|
619
|
+
|
|
620
|
+
args_list = args
|
|
621
|
+
if self.obj is not None:
|
|
622
|
+
args_list = args_list[1:]
|
|
623
|
+
|
|
624
|
+
if predict_phase not in self.obj.phase_cache:
|
|
625
|
+
try:
|
|
626
|
+
predict_phase = self.compile(self.fn.__name__, *args_list, **kwargs)
|
|
627
|
+
except Exception as err:
|
|
628
|
+
_pynative_executor.clear_res()
|
|
629
|
+
raise err
|
|
630
|
+
else: # get compiled args to generate run args by _generate_run_args
|
|
631
|
+
compile_args = self._generate_compile_args(args_list)
|
|
632
|
+
key_id = self._get_key_id()
|
|
633
|
+
compile_args = get_auto_dynamic_shape_args_with_check_input_signature(
|
|
634
|
+
compile_args,
|
|
635
|
+
key_id,
|
|
636
|
+
self.input_signature,
|
|
637
|
+
self._enable_auto_dynamic
|
|
638
|
+
)
|
|
639
|
+
self._compile_args = compile_args
|
|
640
|
+
|
|
641
|
+
new_inputs = self._generate_run_args(args_list, kwargs)
|
|
642
|
+
output = self._graph_executor(
|
|
643
|
+
tuple(new_inputs),
|
|
644
|
+
self.obj.phase_cache[self.obj.phase]
|
|
645
|
+
)
|
|
646
|
+
res = _convert_python_data(output)
|
|
647
|
+
return True, res
|
|
648
|
+
|
|
616
649
|
@_wrap_func
|
|
617
650
|
def __call__(self, *args, **kwargs):
|
|
651
|
+
predict, res = self._predict(*args, **kwargs)
|
|
652
|
+
if predict:
|
|
653
|
+
return res
|
|
654
|
+
if jit_context() and jit_context().is_nested():
|
|
655
|
+
return jit_context().run_graph("", None, *())
|
|
618
656
|
args_list = args
|
|
619
657
|
if self.obj is not None:
|
|
620
658
|
args_list = args_list[1:]
|
|
@@ -634,10 +672,14 @@ class _MindsporeFunctionExecutor:
|
|
|
634
672
|
return None
|
|
635
673
|
|
|
636
674
|
new_inputs = self._generate_run_args(args_list, kwargs)
|
|
637
|
-
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
675
|
+
if context.get_context("mode") == context.PYNATIVE_MODE and not jit_context():
|
|
638
676
|
output = _pynative_executor.grad_jit(*new_inputs)
|
|
639
677
|
else:
|
|
640
678
|
output = self._graph_executor(tuple(new_inputs), phase)
|
|
679
|
+
if jit_context():
|
|
680
|
+
if is_stub_tensor(output):
|
|
681
|
+
output = output.stub_sync()
|
|
682
|
+
return jit_context().run_graph(phase, output, *tuple(new_inputs))
|
|
641
683
|
|
|
642
684
|
return output
|
|
643
685
|
|
|
@@ -653,7 +695,8 @@ class _MindsporeFunctionExecutor:
|
|
|
653
695
|
compile_args = self._generate_compile_args(args)
|
|
654
696
|
key_id = self._get_key_id()
|
|
655
697
|
compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
|
|
656
|
-
self.input_signature
|
|
698
|
+
self.input_signature,
|
|
699
|
+
self._enable_auto_dynamic)
|
|
657
700
|
|
|
658
701
|
# Add mutable for compile_args for two scene:
|
|
659
702
|
# 1) Origin args is mutable.
|
|
@@ -673,7 +716,7 @@ class _MindsporeFunctionExecutor:
|
|
|
673
716
|
f'`{self.fn.__module__}`')
|
|
674
717
|
self.obj.__parse_method__ = method_name
|
|
675
718
|
if isinstance(self.obj, ms.nn.Cell):
|
|
676
|
-
generate_name = generate_name + '.' + str(self.obj.create_time)
|
|
719
|
+
generate_name = generate_name + '.' + str(self.obj.create_time) + self.obj.phase
|
|
677
720
|
create_time = str(self.obj.create_time)
|
|
678
721
|
else:
|
|
679
722
|
generate_name = generate_name + '.' + str(self._create_time)
|
|
@@ -704,7 +747,7 @@ class _MindsporeFunctionExecutor:
|
|
|
704
747
|
|
|
705
748
|
update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
|
|
706
749
|
|
|
707
|
-
if phase in ms_compile_cache and not parameter_hook_updated():
|
|
750
|
+
if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
|
|
708
751
|
# Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
|
|
709
752
|
# generated in generate_arguments_key.
|
|
710
753
|
self._graph_executor.clear_compile_arguments_resource()
|
|
@@ -726,7 +769,7 @@ class _MindsporeFunctionExecutor:
|
|
|
726
769
|
setattr(self.fn.__func__, "__jit_function__", True)
|
|
727
770
|
else:
|
|
728
771
|
setattr(self.fn, "__jit_function__", True)
|
|
729
|
-
is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase
|
|
772
|
+
is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase)
|
|
730
773
|
if isinstance(self.fn, types.MethodType):
|
|
731
774
|
delattr(self.fn.__func__, "__jit_function__")
|
|
732
775
|
else:
|
|
@@ -734,12 +777,14 @@ class _MindsporeFunctionExecutor:
|
|
|
734
777
|
else:
|
|
735
778
|
if isinstance(self.obj, ms.nn.Cell):
|
|
736
779
|
self._graph_executor.set_weights_values(self.obj.parameters_dict())
|
|
737
|
-
is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase
|
|
780
|
+
is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase)
|
|
738
781
|
|
|
739
782
|
if not is_compile:
|
|
740
783
|
raise RuntimeError("Executor compile failed.")
|
|
741
784
|
set_parameter_hook_updated(False)
|
|
742
785
|
ms_compile_cache.add(phase)
|
|
786
|
+
if hasattr(self.obj, "phase"):
|
|
787
|
+
self.obj.phase_cache[self.obj.phase] = phase
|
|
743
788
|
|
|
744
789
|
return phase
|
|
745
790
|
|
|
@@ -760,7 +805,7 @@ class _MindsporeFunctionExecutor:
|
|
|
760
805
|
else:
|
|
761
806
|
key_id = str(id(self.obj)) + str(self._create_time)
|
|
762
807
|
|
|
763
|
-
if _pynative_executor.
|
|
808
|
+
if _pynative_executor.requires_grad():
|
|
764
809
|
key_id = key_id + ".grad"
|
|
765
810
|
return key_id
|
|
766
811
|
|
|
@@ -770,9 +815,9 @@ class _MindsporeFunctionExecutor:
|
|
|
770
815
|
self.fn.__code__.co_firstlineno)
|
|
771
816
|
echo_function_name = "function \"" + self.fn.__name__ + "\" at the file \"" + self.fn.__code__.co_filename \
|
|
772
817
|
+ "\", line " + str(self.fn.__code__.co_firstlineno)
|
|
773
|
-
if _pynative_executor.
|
|
818
|
+
if _pynative_executor.requires_grad():
|
|
774
819
|
generate_name = generate_name + ".grad"
|
|
775
|
-
if
|
|
820
|
+
if self.fn.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
|
|
776
821
|
generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
|
|
777
822
|
return generate_name, echo_function_name
|
|
778
823
|
|
|
@@ -833,6 +878,14 @@ class _MindsporeFunctionExecutor:
|
|
|
833
878
|
"""
|
|
834
879
|
return _get_args_for_run(self, args_list, kwargs, self._compile_args)
|
|
835
880
|
|
|
881
|
+
def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
|
|
882
|
+
"""Get graph proto from pipeline."""
|
|
883
|
+
if use_prefix:
|
|
884
|
+
exec_id = exec_id + '.' + obj.arguments_key
|
|
885
|
+
if self._graph_executor.has_compiled(exec_id) is False:
|
|
886
|
+
return None
|
|
887
|
+
return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
|
|
888
|
+
|
|
836
889
|
|
|
837
890
|
# The attributes used to identify a given object.
|
|
838
891
|
attr_op = {"__str__": lambda x: x.__str__(),
|
|
@@ -845,6 +898,13 @@ attr_op = {"__str__": lambda x: x.__str__(),
|
|
|
845
898
|
}
|
|
846
899
|
|
|
847
900
|
|
|
901
|
+
def _is_inner_func(func):
|
|
902
|
+
"""Check whether the func is an inner func which needs hash_args parameter."""
|
|
903
|
+
# This is a workaround for inner api, should fix it later.
|
|
904
|
+
inner_func = ["after_shard", "_wrap_container"]
|
|
905
|
+
return func.__name__ in inner_func
|
|
906
|
+
|
|
907
|
+
|
|
848
908
|
def _get_obj_id(input_obj):
|
|
849
909
|
"""Get hash id of single object."""
|
|
850
910
|
obj_id = ".".join(
|
|
@@ -859,50 +919,227 @@ def _get_jit_hash(hash_input):
|
|
|
859
919
|
return _get_obj_id(hash_input)
|
|
860
920
|
|
|
861
921
|
|
|
862
|
-
def
|
|
922
|
+
def _get_hash_obj(options):
|
|
923
|
+
hash_obj = None
|
|
924
|
+
if "hash_args" in options:
|
|
925
|
+
hash_obj = _get_jit_hash(options["hash_args"])
|
|
926
|
+
del options["hash_args"]
|
|
927
|
+
return hash_obj
|
|
928
|
+
|
|
929
|
+
|
|
930
|
+
def _check_option_device(option, device):
|
|
931
|
+
"""Check jit options wiwh device"""
|
|
932
|
+
option_device_cfgs = {
|
|
933
|
+
'disable_format_transform': ['GPU'],
|
|
934
|
+
'exec_order': ['Ascend'],
|
|
935
|
+
'ge_options': ['Ascend'],
|
|
936
|
+
'infer_boost': ['Ascend'],
|
|
937
|
+
}
|
|
938
|
+
if option in option_device_cfgs and device not in option_device_cfgs[option]:
|
|
939
|
+
logger.warning(f"For 'jit(options)', the option '{option}' is only support device in "
|
|
940
|
+
f"'{option_device_cfgs[option]}', but got '{device}', ignore it.")
|
|
941
|
+
|
|
942
|
+
|
|
943
|
+
def _check_option_backend(option, backend):
|
|
944
|
+
"""Check jit options wiwh backend"""
|
|
945
|
+
option_backend_cfgs = {
|
|
946
|
+
'disable_format_transform': ['ms_backend'],
|
|
947
|
+
'exec_order': ['ms_backend'],
|
|
948
|
+
'ge_options': ['GE'],
|
|
949
|
+
'infer_boost': ['ms_backend'],
|
|
950
|
+
}
|
|
951
|
+
if option in option_backend_cfgs and backend not in option_backend_cfgs[option]:
|
|
952
|
+
logger.warning(f"For 'jit(options)', the option '{option}' is only support backend in "
|
|
953
|
+
f"'{option_backend_cfgs[option]}', but got '{backend}', ignore it.")
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
def _check_disable_format_transform_value(option, disable_format_transform):
|
|
957
|
+
"""check disable_format_transform option value"""
|
|
958
|
+
if not isinstance(disable_format_transform, bool):
|
|
959
|
+
raise TypeError(f"For 'jit(options)', the type of '{option}' must be bool, "
|
|
960
|
+
f"but got {type(disable_format_transform)}.")
|
|
961
|
+
|
|
962
|
+
|
|
963
|
+
def _check_exec_order_value(option, exec_order):
|
|
964
|
+
"""check exec_order option value"""
|
|
965
|
+
if not isinstance(exec_order, str):
|
|
966
|
+
raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(exec_order)}.")
|
|
967
|
+
|
|
968
|
+
if exec_order not in ['bfs', 'dfs']:
|
|
969
|
+
raise ValueError(f"For '{option}', the value of '{option}' must be one of "
|
|
970
|
+
f"['bfs', 'dfs'], but got '{exec_order}'.")
|
|
971
|
+
|
|
972
|
+
|
|
973
|
+
def _check_ge_options_value(option, ge_options):
|
|
974
|
+
"""check ge_options option value"""
|
|
975
|
+
if not isinstance(ge_options, dict):
|
|
976
|
+
raise TypeError(f"For 'jit(options)', the type of '{option}' must be dict, but got {type(ge_options)}.")
|
|
977
|
+
|
|
978
|
+
for level, options in ge_options.items():
|
|
979
|
+
if level not in ['global', 'session']:
|
|
980
|
+
raise ValueError(f"For '{option}', the key of '{option}' must be one of "
|
|
981
|
+
f"['global', 'session'], but got '{level}'.")
|
|
982
|
+
|
|
983
|
+
if not isinstance(options, dict):
|
|
984
|
+
raise TypeError(f"For '{option}', the type of {level} options must be dict, "
|
|
985
|
+
f"but got {type(options)}. The error options: {options}.")
|
|
986
|
+
|
|
987
|
+
for key, value in options.items():
|
|
988
|
+
if not isinstance(key, str):
|
|
989
|
+
raise TypeError(f"For '{option}', the type of key and value must be str, "
|
|
990
|
+
f"but got {type(key)}. The error key is {key}.")
|
|
991
|
+
if not isinstance(value, str):
|
|
992
|
+
raise TypeError(f"For '{option}', the type of key and value must be str, "
|
|
993
|
+
f"but got {type(value)}. The error value is {value}")
|
|
994
|
+
|
|
995
|
+
|
|
996
|
+
def _check_infer_boost_value(option, value):
|
|
997
|
+
"""check infer_boost option value"""
|
|
998
|
+
if not isinstance(value, str):
|
|
999
|
+
raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(value)}.")
|
|
1000
|
+
|
|
1001
|
+
if value not in ['on', 'off']:
|
|
1002
|
+
raise ValueError(f"For '{option}', the value of '{option}' must be one of ['on', 'off'], but got '{value}'.")
|
|
1003
|
+
|
|
1004
|
+
|
|
1005
|
+
def _check_option_value(option, value):
    """Validate the value of a single jit option.

    Dispatches to the option-specific checker. Unknown options are not rejected: a warning is
    logged and the option is left in place.

    Raises:
        TypeError: Propagated from the option-specific checker.
        ValueError: Propagated from the option-specific checker.
    """
    option_valuecheck_funcs = {
        'disable_format_transform': _check_disable_format_transform_value,
        'exec_order': _check_exec_order_value,
        'ge_options': _check_ge_options_value,
        'infer_boost': _check_infer_boost_value,
    }
    if option in option_valuecheck_funcs:
        option_valuecheck_funcs[option](option, value)
    else:
        # The two sentences are separate literals, so the separating space must be explicit.
        logger.warning(f"For 'jit(options)', the option argument '{option}' is not recognized, please check! "
                       "For detailed usage of 'jit(options)', please refer to the Mindspore official website.")
|
|
1018
|
+
|
|
1019
|
+
|
|
1020
|
+
def _check_options(options, backend):
|
|
1021
|
+
"""Check jit options"""
|
|
1022
|
+
# check whether there are deprecated parameters in the dict `options`.
|
|
1023
|
+
deprecated_args = {'mode': 'capture_mode', 'input_signature': 'dynamic', 'hash_args: ': '',
|
|
1024
|
+
'jit_config': 'jit_level, fullgraph or options', 'compile_once': ''}
|
|
1025
|
+
for key, value in deprecated_args.items():
|
|
1026
|
+
if key in options:
|
|
1027
|
+
log = f"For 'jit', the parameter '{key}' has been deprecated."
|
|
1028
|
+
if value != '':
|
|
1029
|
+
log += f" Please use the parameter '{value}' instead. For more details, please refer to " \
|
|
1030
|
+
f"https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html."
|
|
1031
|
+
logger.warning(log)
|
|
1032
|
+
del options[key]
|
|
1033
|
+
|
|
1034
|
+
# check options' device, backend and value
|
|
1035
|
+
for option, value in options.items():
|
|
1036
|
+
_check_option_backend(option, backend)
|
|
1037
|
+
_check_option_value(option, value)
|
|
1038
|
+
|
|
1039
|
+
|
|
1040
|
+
def jit(
|
|
1041
|
+
function: Optional[Callable] = None,
|
|
1042
|
+
*,
|
|
1043
|
+
capture_mode: str = "ast",
|
|
1044
|
+
jit_level: str = "O0",
|
|
1045
|
+
dynamic: int = 0,
|
|
1046
|
+
fullgraph: bool = False,
|
|
1047
|
+
backend: str = "",
|
|
1048
|
+
**options):
|
|
863
1049
|
"""
|
|
864
1050
|
Create a callable MindSpore graph from a Python function.
|
|
865
1051
|
|
|
866
1052
|
This allows the MindSpore runtime to apply optimizations based on graph.
|
|
867
1053
|
|
|
868
1054
|
Note:
|
|
869
|
-
-
|
|
870
|
-
|
|
871
|
-
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
decorated with @jit(mode=“PIJit”) are not supported,
|
|
875
|
-
and the decoration @jit(mode=“PIJit”) is considered invalid.
|
|
1055
|
+
- It is not supported to run a function with decoration @jit(capture_mode=“bytecode”)
|
|
1056
|
+
in static graph mode, in which case the decoration @jit(capture_mode=“bytecode”) is considered invalid.
|
|
1057
|
+
- Calls to functions with decorated @jit(capture_mode=“bytecode”) inside functions
|
|
1058
|
+
decorated with @jit(capture_mode=“ast”) are not supported,
|
|
1059
|
+
and the decoration @jit(capture_mode=“bytecode”) is considered invalid.
|
|
876
1060
|
|
|
877
1061
|
Args:
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
1062
|
+
function (Function, optional): The Python function that will be run as a graph. Default: ``None``.
|
|
1063
|
+
|
|
1064
|
+
Keyword Args:
|
|
1065
|
+
capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
|
|
1066
|
+
should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
|
|
1067
|
+
|
|
1068
|
+
- `ast <https://www.mindspore.cn/docs/en/r2.5.0/model_train/program_form/static_graph.html>`_ :
|
|
1069
|
+
Parse Python ast to build graph.
|
|
1070
|
+
- `bytecode <https://www.mindspore.cn/docs/en/r2.5.0/model_train/program_form/pynative.html#pijit>`_ :
|
|
1071
|
+
Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
|
|
1072
|
+
change and/or deletion.
|
|
1073
|
+
- `trace` : Trace the execution of Python code to build graph. This is an experimental prototype that is
|
|
1074
|
+
subject to change and/or deletion.
|
|
1075
|
+
|
|
1076
|
+
jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
|
|
1077
|
+
with default backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
|
|
1078
|
+
|
|
1079
|
+
- `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
|
|
1080
|
+
- `O1`: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
|
|
1081
|
+
level is experimental and is being improved.
|
|
1082
|
+
|
|
1083
|
+
dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
|
|
1084
|
+
is as follows:
|
|
1085
|
+
|
|
1086
|
+
- `0`: Do not perform dynamic shape compilation.
|
|
1087
|
+
- `1`: Enable dynamic shape compilation and automatically detect shape changes.
|
|
1088
|
+
|
|
1089
|
+
fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
|
|
1090
|
+
be compatible with all Python syntax in the function as much as possible. If True, we require that the
|
|
1091
|
+
entire function can be captured into graph. If this is not possible (that is, if there is Python syntax
|
|
1092
|
+
not supported), then it will raise an exception. This currently only applies when capture_mode is ast.
|
|
1093
|
+
Default: ``False``.
|
|
1094
|
+
backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
|
|
1095
|
+
use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including Atlas
|
|
1096
|
+
A2 training series products by default.
|
|
1097
|
+
|
|
1098
|
+
- `ms_backend`: Adopt KernelByKernel execution mode.
|
|
1099
|
+
- `GE`: Adopt Sink execution mode. The whole model will be sinked to device to execute, only applicable to
|
|
1100
|
+
the top cell of model. And only can be used in Ascend platform.
|
|
1101
|
+
|
|
1102
|
+
**options (dict): A dictionary of options to pass to the compilation backend.
|
|
1103
|
+
|
|
1104
|
+
Some options are device specific, see the below table for details:
|
|
1105
|
+
|
|
1106
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1107
|
+
| Option Parameters | Hardware Platform Support | Backend Support |
|
|
1108
|
+
+===========================+===========================+=========================+
|
|
1109
|
+
| disable_format_transform | GPU | ms_backend |
|
|
1110
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1111
|
+
| exec_order | Ascend | ms_backend |
|
|
1112
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1113
|
+
| ge_options | Ascend | GE |
|
|
1114
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1115
|
+
| infer_boost | Ascend | ms_backend |
|
|
1116
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1117
|
+
|
|
1118
|
+
- disable_format_transform (bool, optional): Whether to disable the automatic format transform function
|
|
1119
|
+
from NCHW to NHWC. When the network training performance of fp16 is worse than fp32,
|
|
1120
|
+
`disable_format_transform` can be set to ``True`` to try to improve training performance.
|
|
1121
|
+
Default: ``False`` .
|
|
1122
|
+
- exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
|
|
1123
|
+
methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
|
|
1124
|
+
|
|
1125
|
+
- `bfs`: The default sorting method, breadth priority, good communication masking, relatively good
|
|
1126
|
+
performance.
|
|
1127
|
+
- `dfs`: An optional sorting method, depth-first sorting. The performance is relatively worse than that
|
|
1128
|
+
of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
|
|
1129
|
+
other execution orders run out of memory (OOM).
|
|
1130
|
+
|
|
1131
|
+
- ge_options (dict): Set options for ge backend. The options are divided into two categories: global,
|
|
1132
|
+
and session. This is an experimental prototype that is subject to change and/or deletion.
|
|
1133
|
+
For detailed information, please refer to `Ascend community <https://www.hiascend.com/document/detail/zh/canncommercial/80RC3/apiref/ascendgraphapi/atlasgeapi_07_0146.html>`_ .
|
|
1134
|
+
|
|
1135
|
+
- global (dict): Set global options.
|
|
1136
|
+
- session (dict): Set session options.
|
|
1137
|
+
|
|
1138
|
+
- infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
|
|
1139
|
+
the inference mode is disabled. The range is as follows:
|
|
1140
|
+
|
|
1141
|
+
- `on`: Enable inference mode, get better infer performance.
|
|
1142
|
+
- `off`: Disable inference mode, use forward for inference. The performance is poor.
|
|
906
1143
|
|
|
907
1144
|
Returns:
|
|
908
1145
|
Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
|
|
@@ -921,12 +1158,12 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
|
|
|
921
1158
|
>>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
|
922
1159
|
>>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
|
923
1160
|
...
|
|
924
|
-
>>> # create a callable MindSpore graph by calling
|
|
1161
|
+
>>> # create a callable MindSpore graph by calling jit
|
|
925
1162
|
>>> def tensor_add(x, y):
|
|
926
1163
|
... z = x + y
|
|
927
1164
|
... return z
|
|
928
1165
|
...
|
|
929
|
-
>>> tensor_add_graph = jit(
|
|
1166
|
+
>>> tensor_add_graph = jit(function=tensor_add)
|
|
930
1167
|
>>> out = tensor_add_graph(x, y)
|
|
931
1168
|
...
|
|
932
1169
|
>>> # create a callable MindSpore graph through decorator @jit
|
|
@@ -937,180 +1174,70 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
|
|
|
937
1174
|
...
|
|
938
1175
|
>>> out = tensor_add_with_dec(x, y)
|
|
939
1176
|
...
|
|
940
|
-
>>> # create a callable MindSpore graph
|
|
941
|
-
>>> @jit(
|
|
942
|
-
...
|
|
943
|
-
... def tensor_add_with_sig(x, y):
|
|
1177
|
+
>>> # create a callable MindSpore graph and capture the entire function into the graph
|
|
1178
|
+
>>> @jit(fullgraph=True)
|
|
1179
|
+
... def tensor_add_fullgraph(x, y):
|
|
944
1180
|
... z = x + y
|
|
945
1181
|
... return z
|
|
946
1182
|
...
|
|
947
|
-
>>> out =
|
|
948
|
-
...
|
|
949
|
-
>>> @jit(input_signature={"y": Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))})
|
|
950
|
-
... def tensor_add_with_sig_1(x, y):
|
|
951
|
-
... z = x + y
|
|
952
|
-
... return z
|
|
953
|
-
...
|
|
954
|
-
>>> out1 = tensor_add_with_sig_1(x, y)
|
|
955
|
-
...
|
|
956
|
-
... # Set hash_args as fn, otherwise cache of compiled closure_fn will not be reused.
|
|
957
|
-
... # While fn differs during calling again, recompilation will be triggered.
|
|
958
|
-
>>> def func(x):
|
|
959
|
-
... return ops.exp(x)
|
|
960
|
-
...
|
|
961
|
-
>>> def closure_fn(x, fn):
|
|
962
|
-
... @jit(hash_args=fn)
|
|
963
|
-
... def inner_fn(a):
|
|
964
|
-
... return fn(a)
|
|
965
|
-
... return inner_fn(x)
|
|
966
|
-
...
|
|
967
|
-
>>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
|
|
968
|
-
>>> for i in range(10):
|
|
969
|
-
... closure_fn(inputs, func)
|
|
970
|
-
...
|
|
971
|
-
... # Set compile_once = True, otherwise the train_step will be compiled again.
|
|
972
|
-
>>> def train(x):
|
|
973
|
-
... @jit(compile_once = True)
|
|
974
|
-
... def train_step(x):
|
|
975
|
-
... return ops.exp(x)
|
|
976
|
-
... for i in range(10):
|
|
977
|
-
... train_step(x)
|
|
978
|
-
...
|
|
979
|
-
>>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
|
|
980
|
-
>>> for i in range(10):
|
|
981
|
-
... train(inputs)
|
|
1183
|
+
>>> out = tensor_add_fullgraph(x, y)
|
|
982
1184
|
"""
|
|
983
1185
|
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
1186
|
+
capture_mode = Validator.check_string(capture_mode, ["ast", "bytecode", "trace"], "capture_mode", "jit")
|
|
1187
|
+
jit_level = Validator.check_string(jit_level, ["O0", "O1"], "jit_level", "jit")
|
|
1188
|
+
dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
|
|
1189
|
+
fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
|
|
1190
|
+
if backend == "":
|
|
1191
|
+
backend = "GE" if MSContext.get_instance().get_ascend_soc_version() == "ascend910" else "ms_backend"
|
|
1192
|
+
backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
|
|
1193
|
+
jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
|
|
1194
|
+
hash_obj = _get_hash_obj(options)
|
|
1195
|
+
_check_options(options, backend)
|
|
1196
|
+
options_str = json.dumps(options)
|
|
1197
|
+
infer_boost = options['infer_boost'] if 'infer_boost' in options else "off"
|
|
1198
|
+
exc_mode = options['exc_mode'] if 'exc_mode' in options else "auto"
|
|
1199
|
+
jit_config = JitConfig(jit_level=jit_level, exc_mode=exc_mode, jit_syntax_level=jit_syntax_level,
|
|
1200
|
+
infer_boost=infer_boost, backend=backend, options=options_str)
|
|
1201
|
+
|
|
1202
|
+
def wrap_func(func):
|
|
1203
|
+
nonlocal hash_obj
|
|
1204
|
+
if hash_obj is None or not _is_inner_func(func):
|
|
993
1205
|
hash_obj = int(time.time() * 1e9)
|
|
994
1206
|
|
|
995
|
-
dyn_args = _process_dyn_args(func, input_signature)
|
|
996
|
-
|
|
997
1207
|
@wraps(func)
|
|
998
1208
|
def staging_specialize(*args, **kwargs):
|
|
999
1209
|
if os.getenv("MS_JIT") == '0':
|
|
1000
1210
|
return func(*args, **kwargs)
|
|
1001
1211
|
|
|
1002
1212
|
args, kwargs = _handle_func_args(func, *args, **kwargs)
|
|
1003
|
-
|
|
1004
1213
|
process_obj = None
|
|
1005
1214
|
if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
|
|
1006
1215
|
process_obj = args[0]
|
|
1007
|
-
# only the function or cell instance wrapped by shard will fall into this branch
|
|
1008
|
-
if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
|
|
1009
|
-
process_obj = hash_args
|
|
1010
1216
|
# Handle auto mixed precision strategy.
|
|
1011
1217
|
if not hasattr(func, "amp_strategy"):
|
|
1012
1218
|
if isinstance(func, types.MethodType):
|
|
1013
1219
|
setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
|
|
1014
1220
|
else:
|
|
1015
1221
|
setattr(func, "amp_strategy", get_curr_amp_strategy())
|
|
1016
|
-
|
|
1222
|
+
|
|
1223
|
+
ms_function_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic)
|
|
1224
|
+
out = ms_function_executor(*args, **kwargs)
|
|
1017
1225
|
return out
|
|
1018
1226
|
|
|
1019
1227
|
return staging_specialize
|
|
1020
1228
|
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1229
|
+
if capture_mode == "bytecode":
|
|
1230
|
+
wrap_func = PIJitCaptureContext(jit_config)
|
|
1231
|
+
elif capture_mode == "trace":
|
|
1232
|
+
if function is not None:
|
|
1233
|
+
return _jit_trace(function)
|
|
1234
|
+
return _jit_trace
|
|
1024
1235
|
|
|
1025
|
-
if
|
|
1026
|
-
return wrap_func(
|
|
1236
|
+
if function is not None:
|
|
1237
|
+
return wrap_func(function)
|
|
1027
1238
|
return wrap_func
|
|
1028
1239
|
|
|
1029
1240
|
|
|
1030
|
-
def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
|
|
1031
|
-
"""
|
|
1032
|
-
Create a callable MindSpore graph from a Python function.
|
|
1033
|
-
|
|
1034
|
-
This allows the MindSpore runtime to apply optimizations based on graph.
|
|
1035
|
-
|
|
1036
|
-
Note:
|
|
1037
|
-
- `ms_function` will be deprecated and removed in a future version. Please use :func:`mindspore.jit` instead.
|
|
1038
|
-
- If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
|
|
1039
|
-
will not accept `**kwargs`.
|
|
1040
|
-
|
|
1041
|
-
Args:
|
|
1042
|
-
fn (Function): The Python function that will be run as a graph. Default: ``None`` .
|
|
1043
|
-
input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
|
|
1044
|
-
will be supplied to this function. The shape and dtype of actual inputs of `fn` should
|
|
1045
|
-
keep the same as input_signature. Otherwise, TypeError will be raised. Default: ``None`` .
|
|
1046
|
-
hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
|
|
1047
|
-
like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
|
|
1048
|
-
will trigger recompilation. Default: ``None`` .
|
|
1049
|
-
jit_config (JitConfig): Jit config for compile. Default: ``None`` .
|
|
1050
|
-
|
|
1051
|
-
Returns:
|
|
1052
|
-
Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
|
|
1053
|
-
None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
|
|
1054
|
-
equal to the case when `fn` is not None.
|
|
1055
|
-
|
|
1056
|
-
Supported Platforms:
|
|
1057
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
1058
|
-
|
|
1059
|
-
Examples:
|
|
1060
|
-
>>> import numpy as np
|
|
1061
|
-
>>> from mindspore import Tensor
|
|
1062
|
-
>>> from mindspore import ops
|
|
1063
|
-
>>> from mindspore import ms_function
|
|
1064
|
-
...
|
|
1065
|
-
>>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
|
1066
|
-
>>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
|
1067
|
-
...
|
|
1068
|
-
>>> # create a callable MindSpore graph by calling ms_function
|
|
1069
|
-
>>> def tensor_add(x, y):
|
|
1070
|
-
... z = x + y
|
|
1071
|
-
... return z
|
|
1072
|
-
...
|
|
1073
|
-
>>> tensor_add_graph = ms_function(fn=tensor_add)
|
|
1074
|
-
>>> out = tensor_add_graph(x, y)
|
|
1075
|
-
...
|
|
1076
|
-
>>> # create a callable MindSpore graph through decorator @ms_function
|
|
1077
|
-
>>> @ms_function
|
|
1078
|
-
... def tensor_add_with_dec(x, y):
|
|
1079
|
-
... z = x + y
|
|
1080
|
-
... return z
|
|
1081
|
-
...
|
|
1082
|
-
>>> out = tensor_add_with_dec(x, y)
|
|
1083
|
-
...
|
|
1084
|
-
>>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
|
|
1085
|
-
>>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
|
|
1086
|
-
... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
|
|
1087
|
-
... def tensor_add_with_sig(x, y):
|
|
1088
|
-
... z = x + y
|
|
1089
|
-
... return z
|
|
1090
|
-
...
|
|
1091
|
-
>>> out = tensor_add_with_sig(x, y)
|
|
1092
|
-
...
|
|
1093
|
-
... # Set hash_args as fn, otherwise cache of compiled `closure_fn` will not be reused.
|
|
1094
|
-
... # While fn differs during calling again, recompilation will be triggered.
|
|
1095
|
-
>>> def func(x):
|
|
1096
|
-
... return ops.exp(x)
|
|
1097
|
-
...
|
|
1098
|
-
>>> def closure_fn(x, fn):
|
|
1099
|
-
... @ms_function(hash_args=fn)
|
|
1100
|
-
... def inner_fn(a):
|
|
1101
|
-
... return fn(a)
|
|
1102
|
-
... return inner_fn(x)
|
|
1103
|
-
...
|
|
1104
|
-
>>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
|
|
1105
|
-
>>> for i in range(10):
|
|
1106
|
-
... closure_fn(inputs, func)
|
|
1107
|
-
"""
|
|
1108
|
-
|
|
1109
|
-
logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
|
|
1110
|
-
"Please use 'mindspore.jit' instead.")
|
|
1111
|
-
return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
1241
|
def _core(fn=None, **flags):
|
|
1115
1242
|
"""
|
|
1116
1243
|
A decorator that adds a flag to the function.
|
|
@@ -1203,69 +1330,6 @@ def _no_recursive(callable_obj):
|
|
|
1203
1330
|
return callable_obj
|
|
1204
1331
|
|
|
1205
1332
|
|
|
1206
|
-
def ms_class(cls):
|
|
1207
|
-
"""
|
|
1208
|
-
Class decorator for user-defined classes.
|
|
1209
|
-
|
|
1210
|
-
This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
|
|
1211
|
-
|
|
1212
|
-
Note:
|
|
1213
|
-
`ms_class` will be deprecated and removed in a future version. Please use :func:`mindspore.jit_class` instead.
|
|
1214
|
-
|
|
1215
|
-
Args:
|
|
1216
|
-
cls (Class): User-defined class.
|
|
1217
|
-
|
|
1218
|
-
Returns:
|
|
1219
|
-
Class.
|
|
1220
|
-
|
|
1221
|
-
Raises:
|
|
1222
|
-
TypeError: If ms_class is used for non-class types or nn.Cell.
|
|
1223
|
-
AttributeError: If the private attributes or magic methods of the class decorated with ms_class is called.
|
|
1224
|
-
|
|
1225
|
-
Supported Platforms:
|
|
1226
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
1227
|
-
|
|
1228
|
-
Examples:
|
|
1229
|
-
>>> import mindspore.nn as nn
|
|
1230
|
-
>>> from mindspore import ms_class
|
|
1231
|
-
...
|
|
1232
|
-
>>> @ms_class
|
|
1233
|
-
... class UserDefinedNet:
|
|
1234
|
-
... def __init__(self):
|
|
1235
|
-
... self.value = 10
|
|
1236
|
-
...
|
|
1237
|
-
... def func(self, x):
|
|
1238
|
-
... return 2 * x
|
|
1239
|
-
...
|
|
1240
|
-
>>> class Net(nn.Cell):
|
|
1241
|
-
... def __init__(self):
|
|
1242
|
-
... super(Net, self).__init__()
|
|
1243
|
-
... self.net = UserDefinedNet()
|
|
1244
|
-
...
|
|
1245
|
-
... def construct(self, x):
|
|
1246
|
-
... out = self.net.value + self.net.func(x)
|
|
1247
|
-
... return out
|
|
1248
|
-
...
|
|
1249
|
-
>>> net = Net()
|
|
1250
|
-
>>> out = net(5)
|
|
1251
|
-
>>> print(out)
|
|
1252
|
-
20
|
|
1253
|
-
"""
|
|
1254
|
-
|
|
1255
|
-
logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
|
|
1256
|
-
"Please use 'mindspore.jit_class' instead.")
|
|
1257
|
-
|
|
1258
|
-
# Check if cls is of type class.
|
|
1259
|
-
if not inspect.isclass(cls):
|
|
1260
|
-
raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
|
|
1261
|
-
# Check if cls is nn.Cell.
|
|
1262
|
-
if issubclass(cls, ms.nn.Cell):
|
|
1263
|
-
raise TypeError(f"Decorator ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
|
|
1264
|
-
logger.info(f'Found ms_class: {cls}.')
|
|
1265
|
-
setattr(cls, '__ms_class__', True)
|
|
1266
|
-
return cls
|
|
1267
|
-
|
|
1268
|
-
|
|
1269
1333
|
def jit_class(cls):
|
|
1270
1334
|
"""
|
|
1271
1335
|
Class decorator for user-defined classes.
|
|
@@ -1322,28 +1386,6 @@ def jit_class(cls):
|
|
|
1322
1386
|
return cls
|
|
1323
1387
|
|
|
1324
1388
|
|
|
1325
|
-
def set_adapter_config(config):
|
|
1326
|
-
"""
|
|
1327
|
-
Register configuration information for MSAdapter.
|
|
1328
|
-
|
|
1329
|
-
Args:
|
|
1330
|
-
config (dict): Configuration information.
|
|
1331
|
-
"""
|
|
1332
|
-
if not isinstance(config, dict):
|
|
1333
|
-
raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
|
|
1334
|
-
for key, value in config.items():
|
|
1335
|
-
if key == "Tensor":
|
|
1336
|
-
ms_adapter_registry.register_tensor(value)
|
|
1337
|
-
elif key == "Parameter":
|
|
1338
|
-
ms_adapter_registry.register_parameter(value)
|
|
1339
|
-
elif key == "convert_object_map":
|
|
1340
|
-
ms_adapter_registry.register_convert_map(value)
|
|
1341
|
-
elif key == "convert_adapter_tensor_map":
|
|
1342
|
-
ms_adapter_registry.register_convert_adapter_tensor_map(value)
|
|
1343
|
-
else:
|
|
1344
|
-
raise ValueError(f"Unsupported key in adapter config: {key}")
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
1389
|
def _function_forbid_reuse(func):
|
|
1348
1390
|
if not inspect.isfunction(func):
|
|
1349
1391
|
raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
|
|
@@ -1535,7 +1577,24 @@ class _PyNativeExecutor:
|
|
|
1535
1577
|
Return:
|
|
1536
1578
|
None.
|
|
1537
1579
|
"""
|
|
1538
|
-
return self._executor.grad(grad, obj, weights, grad_position, *args)
|
|
1580
|
+
return self._executor.grad(grad, obj, weights, grad_position, False, *args)
|
|
1581
|
+
|
|
1582
|
+
def grad_aux(self, obj, grad, weights, grad_position, *args):
|
|
1583
|
+
"""
|
|
1584
|
+
Run grad graph with aux
|
|
1585
|
+
|
|
1586
|
+
Args:
|
|
1587
|
+
obj (Function/Cell): The function or cell instance.
|
|
1588
|
+
grad (GradOperation): The gradoperation object.
|
|
1589
|
+
weights (ParameterTuple): The weights of cell instance.
|
|
1590
|
+
grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
|
|
1591
|
+
If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
|
|
1592
|
+
args (tuple): Function or cell input arguments.
|
|
1593
|
+
|
|
1594
|
+
Return:
|
|
1595
|
+
None.
|
|
1596
|
+
"""
|
|
1597
|
+
return self._executor.grad(grad, obj, weights, grad_position, True, *args)
|
|
1539
1598
|
|
|
1540
1599
|
def clear_res(self):
|
|
1541
1600
|
"""
|
|
@@ -1671,6 +1730,15 @@ class _PyNativeExecutor:
|
|
|
1671
1730
|
"""
|
|
1672
1731
|
self._executor.set_is_run_recompute(status)
|
|
1673
1732
|
|
|
1733
|
+
def high_order(self):
|
|
1734
|
+
"""
|
|
1735
|
+
Is high order of current scene, this is a inner interface.
|
|
1736
|
+
|
|
1737
|
+
Return:
|
|
1738
|
+
Bool.
|
|
1739
|
+
"""
|
|
1740
|
+
return self._executor.high_order()
|
|
1741
|
+
|
|
1674
1742
|
def set_cell_use_dynamic_shape_process(self, flag):
|
|
1675
1743
|
"""
|
|
1676
1744
|
Set the dynamic shape flag of eval process.
|
|
@@ -1753,7 +1821,6 @@ class _CellGraphExecutor:
|
|
|
1753
1821
|
# create needed graph by lazy mode
|
|
1754
1822
|
self.is_init = False
|
|
1755
1823
|
self.enable_tuple_broaden = False
|
|
1756
|
-
self.obfuscate_config = None # used for model's dynamic obfuscation
|
|
1757
1824
|
self._graph_executor = GraphExecutor_.get_instance()
|
|
1758
1825
|
self._graph_executor.set_py_exe_path(sys.executable)
|
|
1759
1826
|
self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
|
|
@@ -1845,6 +1912,7 @@ class _CellGraphExecutor:
|
|
|
1845
1912
|
Str, the full phase of the cell.
|
|
1846
1913
|
Bool, if the graph has been compiled before, return False, else return True.
|
|
1847
1914
|
"""
|
|
1915
|
+
_init_auto_parallel_context(obj)
|
|
1848
1916
|
obj.__parse_method__ = 'construct'
|
|
1849
1917
|
if not hasattr(obj, obj.__parse_method__):
|
|
1850
1918
|
raise AttributeError(
|
|
@@ -1877,6 +1945,7 @@ class _CellGraphExecutor:
|
|
|
1877
1945
|
# Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
|
|
1878
1946
|
# generated in generate_arguments_key.
|
|
1879
1947
|
self._graph_executor.clear_compile_arguments_resource()
|
|
1948
|
+
_clear_auto_parallel_context(obj)
|
|
1880
1949
|
return phase, False
|
|
1881
1950
|
|
|
1882
1951
|
full_function_name = obj.__class__.__name__ + '.' + str(obj.instance_count) + '.' + str(id(type(obj)))
|
|
@@ -1894,7 +1963,8 @@ class _CellGraphExecutor:
|
|
|
1894
1963
|
else:
|
|
1895
1964
|
jit_config_dict = JitConfig().jit_config_dict
|
|
1896
1965
|
self._graph_executor.set_jit_config(jit_config_dict)
|
|
1897
|
-
|
|
1966
|
+
gc.collect()
|
|
1967
|
+
result = self._graph_executor.compile(obj, args, kwargs, phase)
|
|
1898
1968
|
obj.compile_cache.add(phase)
|
|
1899
1969
|
if not result:
|
|
1900
1970
|
raise RuntimeError("Executor compile failed.")
|
|
@@ -1915,6 +1985,7 @@ class _CellGraphExecutor:
|
|
|
1915
1985
|
self._build_data_graph(obj, phase)
|
|
1916
1986
|
elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
|
|
1917
1987
|
_parameter_broadcast(obj)
|
|
1988
|
+
_clear_auto_parallel_context(obj)
|
|
1918
1989
|
return phase, True
|
|
1919
1990
|
|
|
1920
1991
|
def _update_param_node_default_input(self, phase, replace):
|
|
@@ -1994,25 +2065,12 @@ class _CellGraphExecutor:
|
|
|
1994
2065
|
"""Clear the memory resource of a network."""
|
|
1995
2066
|
self._graph_executor.del_net_res(obj, net_id)
|
|
1996
2067
|
|
|
1997
|
-
def _get_branch_control_input(self):
|
|
1998
|
-
if ('obf_ratio' not in self.obfuscate_config.keys()) or (
|
|
1999
|
-
'obf_random_seed' not in self.obfuscate_config.keys()):
|
|
2000
|
-
raise ValueError("'obf_ratio' and 'obf_random_seed' must be in obfuscate_config.")
|
|
2001
|
-
obf_random_seed = self.obfuscate_config.get('obf_random_seed')
|
|
2002
|
-
if obf_random_seed == 0:
|
|
2003
|
-
branch_control_input = 0
|
|
2004
|
-
else:
|
|
2005
|
-
branch_control_input = _generate_branch_control_input(obf_random_seed)
|
|
2006
|
-
return branch_control_input
|
|
2007
|
-
|
|
2008
2068
|
def _get_func_graph(self, obj, exec_id, use_prefix=False):
|
|
2009
2069
|
"""Get func graph from pipeline."""
|
|
2010
2070
|
if use_prefix:
|
|
2011
2071
|
exec_id = exec_id + '.' + obj.arguments_key
|
|
2012
2072
|
if self._graph_executor.has_compiled(exec_id) is False:
|
|
2013
2073
|
return None
|
|
2014
|
-
if self.obfuscate_config is not None:
|
|
2015
|
-
raise ValueError('For get func graph, obfuscate_config is currently not supported now.')
|
|
2016
2074
|
return self._graph_executor.get_func_graph(exec_id)
|
|
2017
2075
|
|
|
2018
2076
|
def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
|
|
@@ -2021,11 +2079,6 @@ class _CellGraphExecutor:
|
|
|
2021
2079
|
exec_id = exec_id + '.' + obj.arguments_key
|
|
2022
2080
|
if self._graph_executor.has_compiled(exec_id) is False:
|
|
2023
2081
|
return None
|
|
2024
|
-
if self.obfuscate_config is not None:
|
|
2025
|
-
branch_control_input = self._get_branch_control_input()
|
|
2026
|
-
return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, incremental,
|
|
2027
|
-
self.obfuscate_config['obf_ratio'],
|
|
2028
|
-
branch_control_input)
|
|
2029
2082
|
return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
|
|
2030
2083
|
|
|
2031
2084
|
def get_optimize_graph_proto(self, obj):
|
|
@@ -2063,6 +2116,8 @@ def ms_memory_recycle():
|
|
|
2063
2116
|
"""
|
|
2064
2117
|
if ms_compile_cache:
|
|
2065
2118
|
_cell_graph_executor.del_net_res(None, ms_compile_cache)
|
|
2119
|
+
if os.getenv('MS_DEV_JIT_PIPELINE') != '0':
|
|
2120
|
+
JitExecutor_.get_instance().del_net_res(None, ms_compile_cache)
|
|
2066
2121
|
ms_compile_cache.clear()
|
|
2067
2122
|
for cell_cache in cells_compile_cache.values():
|
|
2068
2123
|
if cell_cache:
|
|
@@ -2089,30 +2144,6 @@ def set_recursion_limit(recursion_limit=1000):
|
|
|
2089
2144
|
GraphExecutor_.get_instance().set_max_call_depth(recursion_limit)
|
|
2090
2145
|
|
|
2091
2146
|
|
|
2092
|
-
def _generate_branch_control_input(obf_random_seed):
|
|
2093
|
-
"""Generate append network input for dynamic obfuscation in random seed mode."""
|
|
2094
|
-
seed_max = 2 ** 32 - 1
|
|
2095
|
-
int_max = 2 ** 31 - 1
|
|
2096
|
-
np.random.seed(obf_random_seed % seed_max)
|
|
2097
|
-
# generate a string as hash function inputs
|
|
2098
|
-
word_repo = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghigklmnopqrstuvwxyz" + "0123456789"
|
|
2099
|
-
repo_len = len(word_repo)
|
|
2100
|
-
sha_string = ''
|
|
2101
|
-
string_len = 1024 * 1024
|
|
2102
|
-
for _ in range(string_len):
|
|
2103
|
-
rand_index = np.random.randint(0, repo_len)
|
|
2104
|
-
sha_string += word_repo[rand_index]
|
|
2105
|
-
# get hash result
|
|
2106
|
-
sha_result = hashlib.sha256(sha_string.encode('utf-8')).hexdigest() # len is 64
|
|
2107
|
-
branch_control_input = 1
|
|
2108
|
-
hex_base = 16
|
|
2109
|
-
for item in sha_result:
|
|
2110
|
-
if int(item, hex_base) > 0:
|
|
2111
|
-
branch_control_input *= int(item, hex_base)
|
|
2112
|
-
branch_control_input %= int_max
|
|
2113
|
-
return branch_control_input
|
|
2114
|
-
|
|
2115
|
-
|
|
2116
2147
|
def _bind_device_context():
|
|
2117
2148
|
"""Bind device context to current thread"""
|
|
2118
2149
|
_bind_device_ctx()
|
|
@@ -2135,4 +2166,4 @@ def flops_collection(phase='train'):
|
|
|
2135
2166
|
_cell_graph_executor = _CellGraphExecutor()
|
|
2136
2167
|
_pynative_executor = _PyNativeExecutor()
|
|
2137
2168
|
|
|
2138
|
-
__all__ = ['
|
|
2169
|
+
__all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']
|