mindspore 2.5.0__cp311-cp311-win_amd64.whl → 2.6.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore has been flagged as potentially problematic; see the registry's advisory for this release for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +25 -194
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +109 -75
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +2014 -3386
- mindspore/common/api.py +386 -355
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/generator.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +332 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +228 -571
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +109 -77
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +115 -147
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +133 -702
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +198 -113
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +234 -28
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1253 -179
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +18 -14
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +32 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
- mindspore/ops/auto_generate/gen_extend_func.py +286 -208
- mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
- mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1631 -2347
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3024 -3855
- mindspore/ops/function/nn_func.py +678 -274
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +216 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +8 -5
- mindspore/ops/functional_overload.py +655 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +21 -14
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +39 -24
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +287 -32
- mindspore/ops/operations/debug_ops.py +119 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +67 -224
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +243 -17
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +140 -12
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +658 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -62
- mindspore/parallel/transform_safetensors.py +288 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +37 -13
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +43 -9
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +262 -127
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +2 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2021-
|
|
1
|
+
# Copyright 2021-2025 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -21,6 +21,7 @@ import re
|
|
|
21
21
|
import ast
|
|
22
22
|
import hashlib
|
|
23
23
|
import stat
|
|
24
|
+
import copy
|
|
24
25
|
import inspect
|
|
25
26
|
import importlib
|
|
26
27
|
import platform
|
|
@@ -29,7 +30,6 @@ import numpy as np
|
|
|
29
30
|
import mindspore as ms
|
|
30
31
|
from mindspore._c_expression import Oplib, typing
|
|
31
32
|
from mindspore._c_expression import pyboost_custom_ext
|
|
32
|
-
from mindspore.common._stub_tensor import _convert_stub
|
|
33
33
|
from mindspore import context
|
|
34
34
|
from mindspore.common import Tensor
|
|
35
35
|
from mindspore.common import dtype as mstype
|
|
@@ -40,6 +40,7 @@ from mindspore.communication.management import get_rank, GlobalComm
|
|
|
40
40
|
from ._ms_kernel import determine_variable_usage
|
|
41
41
|
from ._custom_grad import autodiff_bprop
|
|
42
42
|
from ._pyfunc_registry import add_pyfunc
|
|
43
|
+
from ._custom_ops_utils import ExtensionBuilder
|
|
43
44
|
|
|
44
45
|
if platform.system() != "Windows":
|
|
45
46
|
import fcntl
|
|
@@ -109,12 +110,19 @@ def _compile_aot(file):
|
|
|
109
110
|
func_path = cache_path + file_name + ".so"
|
|
110
111
|
include_file = "{} -I{}".format(include_file, file[:file.rindex('/')])
|
|
111
112
|
|
|
113
|
+
if context.get_context("device_target") == "Ascend":
|
|
114
|
+
ascend_cann_path = os.getenv("ASCEND_OPP_PATH").split('opp')[0]
|
|
115
|
+
ascend_include = os.path.join(ascend_cann_path, "include")
|
|
116
|
+
include_file = "{} -I{}".format(include_file, ascend_include)
|
|
117
|
+
|
|
118
|
+
include_file = include_file.split(" ")
|
|
112
119
|
if func_path not in Custom.compiled_bin:
|
|
113
120
|
Custom.compiled_bin.append(func_path)
|
|
114
121
|
|
|
115
122
|
if file.endswith("cpp") or file.endswith("cc"):
|
|
116
123
|
cmd = ["g++", "-std=c++17", "--shared", "-fPIC", "-D_GLIBCXX_USE_CXX11_ABI=0"]
|
|
117
|
-
cmd +=
|
|
124
|
+
cmd += include_file
|
|
125
|
+
cmd += ["-o", func_path, file]
|
|
118
126
|
elif file.endswith("cu"):
|
|
119
127
|
cmd = ["nvcc"]
|
|
120
128
|
cmd += ["--shared", "-Xcompiler", "-fPIC", "-O3", "-gencode", "arch=compute_70, code=sm_70"]
|
|
@@ -141,12 +149,13 @@ def _compile_aot(file):
|
|
|
141
149
|
logger.warning("The current version of nvcc, V{}.{}.{}, might have unfixed issues with std string, "
|
|
142
150
|
"which will lead to errors in aot custom op with attrs."
|
|
143
151
|
"The version higher than V10.1.168 is recommended".format(v_major, v_mid, v_minor))
|
|
144
|
-
cmd +=
|
|
152
|
+
cmd += include_file
|
|
153
|
+
cmd += ["-o", func_path, file]
|
|
145
154
|
else:
|
|
146
155
|
raise ValueError("The source file must be a cc/cpp/cu file, but get: {}".format(file))
|
|
147
156
|
|
|
148
157
|
proc = subprocess.Popen(
|
|
149
|
-
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
|
158
|
+
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=False)
|
|
150
159
|
|
|
151
160
|
(out, _) = proc.communicate(timeout=30)
|
|
152
161
|
|
|
@@ -183,10 +192,16 @@ class _CustomExt(ops.PrimitiveWithInfer):
|
|
|
183
192
|
|
|
184
193
|
infer_value = None
|
|
185
194
|
if infer_shape is None:
|
|
186
|
-
logger.
|
|
187
|
-
|
|
188
|
-
|
|
195
|
+
logger.debug("'out_shape' is None. Add a placeholder instead. "
|
|
196
|
+
"A CPP version of infer shape function is required "
|
|
197
|
+
"in this case.")
|
|
189
198
|
infer_shape = (1,)
|
|
199
|
+
if infer_dtype is None:
|
|
200
|
+
logger.debug("'out_dtype' is None. Add a placeholder instead. "
|
|
201
|
+
"A CPP version of infer type function is required "
|
|
202
|
+
"in this case.")
|
|
203
|
+
infer_dtype = ms.float16
|
|
204
|
+
|
|
190
205
|
# after all automatic infer information fulfillment, throw error if infer_shape/infer_dtype is still None
|
|
191
206
|
if not isinstance(infer_shape, (tuple, list)):
|
|
192
207
|
raise TypeError("'out_shape' must be one of [tuple, list, function], but got {}".format(type(infer_shape)))
|
|
@@ -215,7 +230,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
215
230
|
function if needed. Then these `Custom` objects can be directly used in neural networks.
|
|
216
231
|
Detailed description and introduction of user-defined operators, including correct writing of parameters,
|
|
217
232
|
please refer to `Custom Operators Tutorial
|
|
218
|
-
<https://www.mindspore.cn/
|
|
233
|
+
<https://www.mindspore.cn/tutorials/en/master/custom_program/op_custom.html>`_ .
|
|
219
234
|
|
|
220
235
|
.. warning::
|
|
221
236
|
- This is an experimental API that is subject to change.
|
|
@@ -297,14 +312,17 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
297
312
|
b) Ascend platform.
|
|
298
313
|
Before using Custom operators on the Ascend platform, users must first develop custom operators
|
|
299
314
|
based on Ascend C and compile them. The complete development and usage process can refer to the
|
|
300
|
-
tutorial `AOT-Type Custom Operators(Ascend)
|
|
315
|
+
tutorial `AOT-Type Custom Operators(Ascend)
|
|
316
|
+
<https://www.mindspore.cn/tutorials/en/master/custom_program/operation/op_custom_ascendc.html>`_.
|
|
301
317
|
By passing the name of the operator through the input parameter `func`, there are two usage methods
|
|
302
|
-
based on the implementation of the infer
|
|
303
|
-
|
|
304
|
-
- Python infer: If the operator's infer
|
|
305
|
-
function is passed through the `out_shape` parameter,
|
|
306
|
-
|
|
307
|
-
|
|
318
|
+
based on the implementation of the infer function:
|
|
319
|
+
|
|
320
|
+
- Python infer: If the operator's infer function is implemented in Python, that is, the infer shape
|
|
321
|
+
function is passed through the `out_shape` parameter, and the infer type is passed through the
|
|
322
|
+
`out_dtype`, then the `func` should be specified as the operator name, for example,
|
|
323
|
+
`func="CustomName"`.
|
|
324
|
+
- C++ infer: If the operator's infer function is implemented through C++, then pass the path of the
|
|
325
|
+
infer function implementation file in `func` and separate the operator name with `:`,
|
|
308
326
|
for example: `func="add_custom_infer.cc:AddCustom"` .
|
|
309
327
|
|
|
310
328
|
2. for "julia":
|
|
@@ -425,6 +443,8 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
425
443
|
self._func_compile_attrs = {}
|
|
426
444
|
self._is_ms_kernel = False
|
|
427
445
|
self.out_shape = out_shape
|
|
446
|
+
self.out_dtype = out_dtype
|
|
447
|
+
self.is_ascend_c = context.get_context("device_target") == "Ascend" and self.func_type == "aot"
|
|
428
448
|
|
|
429
449
|
self._check_platform()
|
|
430
450
|
self._check_func()
|
|
@@ -440,15 +460,17 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
440
460
|
add_pyfunc(func_id, self.func)
|
|
441
461
|
self.add_prim_attr("fn_id", func_id)
|
|
442
462
|
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
self.
|
|
463
|
+
self.set_infer_flag()
|
|
464
|
+
|
|
465
|
+
self.multi_output = (reg_info is not None and (len(reg_info.get("outputs", [])) > 1))
|
|
466
|
+
self.add_prim_attr("multi_output", self.multi_output)
|
|
467
|
+
|
|
446
468
|
self.bprop = bprop
|
|
447
469
|
self.fake_output = False
|
|
448
470
|
self.single_scalar_output = False
|
|
449
|
-
if not self.out_dtype:
|
|
471
|
+
if not self.out_dtype and not self.func_type == "pyfunc":
|
|
450
472
|
self.fake_output = True
|
|
451
|
-
elif not self.out_shape:
|
|
473
|
+
elif not self.out_shape and self.func_type == "pyfunc":
|
|
452
474
|
self.single_scalar_output = True
|
|
453
475
|
self.add_prim_attr("fake_output", self.fake_output)
|
|
454
476
|
self.add_prim_attr("single_scalar_output", self.single_scalar_output)
|
|
@@ -464,13 +486,28 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
464
486
|
|
|
465
487
|
self.add_prim_attr("func_type", self.func_type)
|
|
466
488
|
self._update_attr()
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
self.
|
|
489
|
+
|
|
490
|
+
if self.is_ascend_c:
|
|
491
|
+
self.set_inputs_type(reg_info)
|
|
492
|
+
self.custom_pyboost = _CustomExt(self.func, self.out_shape, self.out_dtype, self.bprop)
|
|
471
493
|
for key, value in super().get_attr_dict().items():
|
|
472
494
|
self.custom_pyboost.add_prim_attr(key, value)
|
|
473
495
|
|
|
496
|
+
def set_infer_flag(self):
    """Request C++-side inference for AOT operators that lack Python infer information.

    Only applies when ``func_type`` is ``"aot"``:

    - a missing ``out_shape`` (``None``) sets the primitive attribute ``cpp_infer_shape`` to ``True``;
    - a missing ``out_dtype`` (``None``) sets the primitive attribute ``cpp_infer_type`` to ``True``.
    """
    if self.func_type != "aot":
        return
    if self.out_shape is None:
        self.add_prim_attr("cpp_infer_shape", True)
    if self.out_dtype is None:
        self.add_prim_attr("cpp_infer_type", True)
|
|
502
|
+
|
|
503
|
+
def set_inputs_type(self, reg_info):
    """Record the kind of every operator input in the ``custom_inputs_type`` primitive attribute.

    Only applies to Ascend C operators whose registration info declares attributes. Those
    attributes are passed to the kernel as trailing inputs, so the recorded list is
    ``"tensor"`` once per declared input followed by each attribute's ``type``.

    Args:
        reg_info (dict or None): Registration info. ``None`` or empty means there is nothing
            to record; this happens when the operator was created without ``reg_info``.
    """
    # reg_info may legitimately be None here (the constructor checks `reg_info is not None`
    # elsewhere), so guard before calling .get() on it.
    if not self.is_ascend_c or not reg_info:
        return
    attrs = reg_info.get("attr")
    if not attrs:
        return
    tensor_inputs = ["tensor"] * len(reg_info.get("inputs", []))
    attr_inputs = [attr.get("type") for attr in attrs]
    self.add_prim_attr("custom_inputs_type", tensor_inputs + attr_inputs)
|
|
510
|
+
|
|
474
511
|
def __infer__(self, *args):
|
|
475
512
|
if callable(self.out_shape):
|
|
476
513
|
infer_shape = self.out_shape(*(x["shape"] for x in args))
|
|
@@ -510,10 +547,15 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
510
547
|
infer_dtype = mstype.int32
|
|
511
548
|
if self.func_type == "aot":
|
|
512
549
|
if infer_shape is None:
|
|
513
|
-
logger.
|
|
514
|
-
|
|
515
|
-
|
|
550
|
+
logger.debug("{}, 'out_shape' is None. Add a placeholder instead. "
|
|
551
|
+
"A CPP version of infer shape function is required "
|
|
552
|
+
"in this case.".format(self.log_prefix))
|
|
516
553
|
infer_shape = (1,)
|
|
554
|
+
if infer_dtype is None:
|
|
555
|
+
logger.debug("{}, 'out_dtype' is None. Add a placeholder instead. "
|
|
556
|
+
"A CPP version of infer type function is required "
|
|
557
|
+
"in this case.".format(self.log_prefix))
|
|
558
|
+
infer_dtype = ms.float16
|
|
517
559
|
# after all automatic infer information fulfillment, throw error if infer_shape/infer_dtype is still None
|
|
518
560
|
if not isinstance(infer_shape, (tuple, list)):
|
|
519
561
|
raise TypeError("{}, 'out_shape' must be one of [tuple, list, function], but got {}"
|
|
@@ -757,6 +799,26 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
757
799
|
if isinstance(item, dict) and item.get("value") is not None:
|
|
758
800
|
self.add_prim_attr(item[KEY_NAME], item["value"])
|
|
759
801
|
|
|
802
|
+
def _convert_attr_to_input(self, ori_reg_info):
    """Return registration info in which every attribute is re-declared as a trailing input.

    Only applies to Ascend C operators that declare attributes; otherwise ``ori_reg_info`` is
    returned unchanged (the same object). When conversion happens, a deep copy is modified and
    ``ori_reg_info`` itself is left untouched.

    For the copy:

    - each attribute is appended to ``inputs`` as ``{'index', 'name', 'paramType'}``, with
      ``index`` continuing after the existing inputs;
    - every ``dtype_format`` row receives one ``DataType.None_None`` placeholder per attribute,
      inserted right after its first ``len(inputs)`` entries (presumably the input columns);
    - ``attr`` is emptied.

    Args:
        ori_reg_info (dict): Reformatted registration info.

    Returns:
        dict, the converted registration info, or ``ori_reg_info`` when no conversion applies.
    """
    if not self.is_ascend_c or not ori_reg_info.get("attr"):
        return ori_reg_info

    reg_info = copy.deepcopy(ori_reg_info)
    # setdefault: an operator may declare attributes but no tensor inputs.
    inputs = reg_info.setdefault("inputs", [])
    attrs = reg_info.get("attr", [])
    start_index = len(inputs)
    for offset, attr_item in enumerate(attrs):
        inputs.append({
            'index': start_index + offset,
            'name': attr_item['name'],
            'paramType': attr_item['paramType']})

    # Single pass over the rows: insert all placeholders at once, addressing each row by its
    # position instead of searching for it with list.index().
    dtype_formats = reg_info.get("dtype_format", [])
    for row_idx, row in enumerate(dtype_formats):
        new_row = list(row)
        new_row[start_index:start_index] = [DataType.None_None] * len(attrs)
        dtype_formats[row_idx] = new_row

    reg_info['attr'] = []
    return reg_info
|
|
821
|
+
|
|
760
822
|
def _register_info(self, info):
|
|
761
823
|
"""Register reg_info."""
|
|
762
824
|
reg_info = info
|
|
@@ -787,14 +849,15 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
787
849
|
continue
|
|
788
850
|
# Register
|
|
789
851
|
reg_info = self._reformat_reg_info(reg_info, target)
|
|
790
|
-
|
|
852
|
+
new_reg_info = self._convert_attr_to_input(reg_info)
|
|
853
|
+
reg_info_str = json.dumps(new_reg_info)
|
|
791
854
|
op_lib = Oplib()
|
|
792
855
|
if not op_lib.reg_op(reg_info_str, self.imply_path):
|
|
793
856
|
raise ValueError("{}, the registration information is registered failed. Use 'CustomRegOp' to "
|
|
794
857
|
"generate the registration information, then pass it to 'reg_info' or use "
|
|
795
858
|
"'custom_info_register' to bind it to 'func' if 'func' is a function."
|
|
796
859
|
.format(self.log_prefix))
|
|
797
|
-
self._save_attr(
|
|
860
|
+
self._save_attr(new_reg_info)
|
|
798
861
|
self._save_register_status(target)
|
|
799
862
|
|
|
800
863
|
def _get_expanded_list(self, data):
|
|
@@ -1078,10 +1141,202 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
1078
1141
|
return infer_shape, infer_dtype, infer_value
|
|
1079
1142
|
|
|
1080
1143
|
def __call__(self, *args):
    """Execute the custom operator eagerly.

    Ascend C operators (``is_ascend_c``) are dispatched via ``pyboost_custom_ext`` using the
    pre-built ``custom_pyboost`` primitive. The output sequence is returned as-is for
    multi-output operators, otherwise its first element. Every other operator first consults
    ``check_elim`` and then falls back to the generic ``ops.primitive._run_op`` path.
    """
    if not self.is_ascend_c:
        eliminated, elim_output = self.check_elim(*args)
        if eliminated:
            return elim_output
        # pylint: disable=protected-access
        return ops.primitive._run_op(self, self.name, args)

    outputs = pyboost_custom_ext(self.custom_pyboost, [args])
    if self.multi_output:
        return outputs
    return outputs[0]
|
|
1152
|
+
|
|
1153
|
+
|
|
1154
|
+
class CustomOpBuilder:
    r"""
    CustomOpBuilder is used to initialize and configure custom operators for MindSpore.
    Users can define and load custom operator modules through this class and apply them to the network.

    In most cases, users only need to provide the source files and additional compilation options in the constructor
    and call the `load` method to complete the compilation and loading of the operator.
    If users have specific customization requirements, they can inherit this class and override certain methods.
    It is important to note that if methods are overridden, some parameters passed to the constructor may be ignored.

    .. warning::
        This is an experimental API that is subject to change.

    Args:
        name (str): The unique name of the custom operator module, used to identify the operator.
        sources (Union[str, list[str]]): The source file(s) of the custom operator. It can be a single file path or
            a list of file paths.
        backend (str, optional): The target backend for the operator, such as "CPU" or "Ascend". Default: ``None``.
        include_paths (Union[str, list[str]], optional): Additionally included paths needed during compilation.
            The given list is never modified. Default: ``None``.
        cflags (str, optional): Extra C++ compiler flags to be used during compilation. Default: ``None``.
        ldflags (str, optional): Extra linker flags to be used during linking. Default: ``None``.
        kwargs (dict, optional): Additional keyword arguments for future extensions or specific custom requirements.
            Currently ``build_dir`` (str) selects the build directory.

    Raises:
        RuntimeError: If `backend` is ``"Ascend"`` and the environment variable ``ASCEND_OPP_PATH`` is not set.

    .. note::
        - If the `backend` argument is provided, additional default flags will be automatically added to
          the compilation and linking steps to support the operator's target backend. The default options
          can be referenced in the implementation of the `get_cflags` and `get_ldflags` methods in the `CustomOpBuilder
          <https://gitee.com/mindspore/mindspore/blob/master/mindspore/python/mindspore/ops/operations/custom_ops.py>`_.
        - The `sources` argument must point to valid source files for the custom operator.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> builder = ops.CustomOpBuilder(
        ...     name="custom_op_cpu",
        ...     sources="custom_ops_impl/pybind_op_cpu.cpp",
        ...     backend="CPU"
        ... )
        >>> my_ops = builder.load()
    """
    _mindspore_path = None
    _loaded_ops = {}
    _ms_code_base = None

    def __init__(self, name, sources, backend=None, include_paths=None, cflags=None, ldflags=None, **kwargs):
        self.name = name
        self.source = sources
        self.backend = backend
        self.include_paths = include_paths
        self.cflags = cflags
        self.ldflags = ldflags
        self.build_dir = kwargs.get("build_dir")
        # The MindSpore install location is resolved once and shared by all builders.
        if CustomOpBuilder._mindspore_path is None:
            CustomOpBuilder._mindspore_path = os.path.dirname(os.path.abspath(ms.__file__))
            CustomOpBuilder._ms_code_base = os.path.join(CustomOpBuilder._mindspore_path, "include")
        if self.backend == "Ascend":
            self.ascend_cann_path = self._get_ascend_cann_path()

    @staticmethod
    def _get_ascend_cann_path():
        """Derive the CANN root directory from the ``ASCEND_OPP_PATH`` environment variable.

        ``ASCEND_OPP_PATH`` is expected to point at the ``opp`` directory of a CANN install. Everything before
        its last ``opp`` occurrence is returned, keeping the trailing separator.

        Raises:
            RuntimeError: If ``ASCEND_OPP_PATH`` is unset or empty.
        """
        opp_path = os.getenv("ASCEND_OPP_PATH")
        if not opp_path:
            raise RuntimeError("The environment variable 'ASCEND_OPP_PATH' is not set. Please configure the "
                               "Ascend CANN environment before building a custom operator for the 'Ascend' "
                               "backend.")
        # Split on the *last* 'opp' so parent directories whose names contain 'opp' are kept intact.
        return opp_path.rsplit('opp', 1)[0]

    def get_sources(self):
        """
        Get the source files for the custom operator.

        Returns:
            str or list[str], The source file(s) for the operator.
        """
        return self.source

    def get_include_paths(self):
        """
        Get the include paths required for compiling the custom operator.

        User-supplied `include_paths` come first, followed by MindSpore's include directories and, for the
        ``"Ascend"`` backend, the CANN include directory plus MindSpore's inner-module include directories.
        The user's list is copied rather than extended, so repeated calls return the same result.

        Returns:
            list[str], A list of include paths.
        """
        if self.include_paths is None:
            include_list = []
        elif isinstance(self.include_paths, str):
            # A single path passed as a plain string is treated as a one-element list.
            include_list = [self.include_paths]
        else:
            # Copy so appending the defaults below never mutates the caller's list.
            include_list = list(self.include_paths)
        ms_path = CustomOpBuilder._mindspore_path
        include_list.append(ms_path)
        include_list.append(os.path.join(ms_path, "include"))
        include_list.append(os.path.join(ms_path, "include/third_party"))
        include_list.append(os.path.join(ms_path, "include/third_party/robin_hood_hashing"))
        include_list.append(os.path.join(ms_path, "include/third_party/securec/include"))

        if self.backend == "Ascend":
            include_list.append(os.path.join(self.ascend_cann_path, "include"))
            include_list += self._get_ms_inner_includes()
        return include_list

    def _get_ms_inner_includes(self):
        """include paths for inner module interface."""
        ms_inner_code_base = os.path.join(CustomOpBuilder._mindspore_path, "include", "mindspore")
        include_list = []
        include_list.append(ms_inner_code_base + "/core/include")
        include_list.append(ms_inner_code_base + "/core/mindrt/include")
        include_list.append(ms_inner_code_base + "/core/mindrt")
        include_list.append(ms_inner_code_base + "/ops")
        include_list.append(ms_inner_code_base + "/ops/kernel/include")
        include_list.append(ms_inner_code_base + "/ccsrc")
        include_list.append(ms_inner_code_base + "/ccsrc/include")
        include_list.append(ms_inner_code_base + "/ccsrc/minddata/mindrecord/include")
        return include_list

    def get_cflags(self):
        """
        Get the C++ compiler flags for building the custom operator.

        The user's `cflags` string, if any, is appended as a single final element.

        Returns:
            list[str], A list of C++ compiler flags.
        """
        flags = ['-fstack-protector-all', '-fPIC', '-pie']
        flags += ['-DENABLE_FAST_HASH_TABLE=1']
        if self.backend == "Ascend":
            flags.append('-DCUSTOM_ASCEND_OP')
        if self.cflags is not None:
            flags.append(self.cflags)
        return flags

    def get_ldflags(self):
        """
        Get the linker flags for building the custom operator.

        The user's `ldflags` string, if any, is appended as a single final element.

        Returns:
            list[str], A list of linker flags.
        """
        flags = ['-Wl,-z,relro,-z,now,-z,noexecstack', '-Wl,--disable-new-dtags,--rpath', '-s']
        flags += [
            '-L' + os.path.abspath(os.path.join(CustomOpBuilder._mindspore_path, 'lib')),
            '-lmindspore_core',
            '-lmindspore_ms_backend',
            '-lmindspore_pynative'
        ]
        if self.backend == "Ascend":
            flags.append('-L' + os.path.abspath(os.path.join(CustomOpBuilder._mindspore_path, 'lib/plugin')))
            flags.append('-L' + os.path.abspath(os.path.join(self.ascend_cann_path, "lib64")))
            flags.append('-lascendcl')
            flags.append('-l:libmindspore_ascend.so.2')
        if self.ldflags is not None:
            flags.append(self.ldflags)
        return flags

    def build(self):
        """
        Build the custom operator module.

        This method generates a dynamic library file for the custom operator based on the provided source files,
        include paths, compilation flags, and link flags.

        Returns:
            str, The path to the compiled module.
        """
        return ExtensionBuilder(self._get_build_directory()).build(
            module_name=self.name,
            sources=self.get_sources(),
            extra_include_paths=self.get_include_paths(),
            extra_cflags=self.get_cflags(),
            extra_ldflags=self.get_ldflags())

    def load(self):
        """
        Build and load the custom operator module.

        Modules are cached by `name` for the lifetime of the process; a second `load` with the same name
        returns the cached module without rebuilding.

        Returns:
            Module, The loaded custom operator module.
        """
        if self.name in CustomOpBuilder._loaded_ops:
            return CustomOpBuilder._loaded_ops[self.name]
        module_path = self.build()
        mod = self._import_module(module_path)
        CustomOpBuilder._loaded_ops[self.name] = mod
        return mod

    def _import_module(self, module_path):
        """Import module from library."""
        spec = importlib.util.spec_from_file_location(self.name, module_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def _get_build_directory(self):
        """Resolve (and create if needed) the build directory.

        Defaults to ``<MS_COMPILER_CACHE_PATH or ./kernel_meta>/<name>``; an explicit ``build_dir`` is used as-is
        after resolving it to a real path.
        """
        if self.build_dir is None:
            build_root = os.path.realpath(os.getenv('MS_COMPILER_CACHE_PATH', "./kernel_meta"))
            self.build_dir = os.path.join(build_root, self.name)
        else:
            self.build_dir = os.path.realpath(self.build_dir)
        logger.info(f'Build {self.name} in directory {self.build_dir}')
        os.makedirs(self.build_dir, exist_ok=True)
        return self.build_dir
|
|
@@ -19,12 +19,13 @@ from pathlib import Path
|
|
|
19
19
|
import numpy as np
|
|
20
20
|
from mindspore import log as logger
|
|
21
21
|
from mindspore._c_expression import security, HookType
|
|
22
|
-
from mindspore._c_expression import
|
|
22
|
+
from mindspore._c_expression import TensorPy as Tensor_
|
|
23
23
|
from mindspore._c_expression import _tensordump_process_file
|
|
24
24
|
from mindspore import _checkparam as validator
|
|
25
25
|
from mindspore.common import dtype as mstype
|
|
26
26
|
from mindspore.common.parameter import Parameter
|
|
27
27
|
from mindspore.common.tensor import Tensor
|
|
28
|
+
from mindspore.common.jit_context import jit_context
|
|
28
29
|
from mindspore.ops.primitive import prim_attr_register, Primitive, PrimitiveWithInfer
|
|
29
30
|
from mindspore._checkparam import check_hook_fn
|
|
30
31
|
from mindspore.ops import operations as P
|
|
@@ -62,9 +63,7 @@ class ScalarSummary(Primitive):
|
|
|
62
63
|
"""
|
|
63
64
|
This operator will put a scalar to a summary file with protocol buffer format.
|
|
64
65
|
It must be used with :class:`mindspore.SummaryRecord` or :class:`mindspore.SummaryCollector`,
|
|
65
|
-
which specify the directory of the summary file.
|
|
66
|
-
be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
|
|
67
|
-
mindinsight/docs/en/master/index.html>`_ for details.
|
|
66
|
+
which specify the directory of the summary file.
|
|
68
67
|
In Ascend platform with graph mode, the environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
|
|
69
68
|
can be set to solve operator execution failure when calling this operator intensively.
|
|
70
69
|
|
|
@@ -122,11 +121,9 @@ class ScalarSummary(Primitive):
|
|
|
122
121
|
class ImageSummary(Primitive):
|
|
123
122
|
"""
|
|
124
123
|
This operator will put an image tensor to a summary file with protocol buffer format. It must be used with
|
|
125
|
-
SummaryRecord or SummaryCollector, which specify the directory of the summary file.
|
|
126
|
-
be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
|
|
127
|
-
mindinsight/docs/en/master/index.html>`_ for details.
|
|
124
|
+
SummaryRecord or SummaryCollector, which specify the directory of the summary file.
|
|
128
125
|
In Ascend platform with graph mode, the environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
|
|
129
|
-
can be set to solve
|
|
126
|
+
can be set to solve execution failure when calling this operator intensively.
|
|
130
127
|
|
|
131
128
|
Inputs:
|
|
132
129
|
- **name** (str) - The name of the input variable, it must not be an empty string.
|
|
@@ -175,9 +172,7 @@ class ImageSummary(Primitive):
|
|
|
175
172
|
class TensorSummary(Primitive):
|
|
176
173
|
"""
|
|
177
174
|
This operator will put a tensor to a summary file with protocol buffer format. It must be used with SummaryRecord
|
|
178
|
-
or SummaryCollector, which specify the directory of the summary file.
|
|
179
|
-
be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
|
|
180
|
-
mindinsight/docs/en/master/index.html>`_ for details.
|
|
175
|
+
or SummaryCollector, which specify the directory of the summary file.
|
|
181
176
|
In Ascend platform with graph mode, the environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
|
|
182
177
|
can be set to solve operator execution failure when calling this operator intensively.
|
|
183
178
|
|
|
@@ -236,7 +231,7 @@ class TensorDump(Primitive):
|
|
|
236
231
|
Save the Tensor as an npy file in numpy format.
|
|
237
232
|
|
|
238
233
|
.. warning::
|
|
239
|
-
|
|
234
|
+
The parameter input_output will no longer support the value 'all'.
|
|
240
235
|
|
|
241
236
|
.. note::
|
|
242
237
|
In Ascend platform with graph mode, the environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
|
|
@@ -344,8 +339,6 @@ class HistogramSummary(Primitive):
|
|
|
344
339
|
"""
|
|
345
340
|
This operator will calculate the histogram of a tensor and put it to a summary file with protocol buffer format.
|
|
346
341
|
It must be used with SummaryRecord or SummaryCollector, which specify the directory of the summary file.
|
|
347
|
-
The summary file can be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
|
|
348
|
-
mindinsight/docs/en/master/index.html>`_ for details.
|
|
349
342
|
In Ascend platform with graph mode, the environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
|
|
350
343
|
can be set to solve operator execution failure when calling this operator intensively.
|
|
351
344
|
|
|
@@ -403,6 +396,10 @@ class InsertGradientOf(Primitive):
|
|
|
403
396
|
"""
|
|
404
397
|
Attaches callback to the graph node that will be invoked on the node's gradient.
|
|
405
398
|
|
|
399
|
+
.. warning::
|
|
400
|
+
In the callback, exercise caution when using side-effect operators,
|
|
401
|
+
such as the TensorDump operator, as current support is incomplete.
|
|
402
|
+
|
|
406
403
|
Args:
|
|
407
404
|
f (Function): MindSpore's Function. Callback function.
|
|
408
405
|
|
|
@@ -466,6 +463,107 @@ class InsertGradientOf(Primitive):
|
|
|
466
463
|
self.f = f
|
|
467
464
|
|
|
468
465
|
|
|
466
|
+
class Morph(PrimitiveWithInfer):
    """
    The `Morph` Primitive wraps a user-defined function `fn` so that it can be used as a custom
    Primitive.
    Its main use case is auto-parallel scenarios under `GRAPH_MODE`, where collective communication
    operators are used inside the user-defined `fn` to implement custom parallel computation logic,
    especially when `fn` involves dynamic shapes.
    Applying the `Morph` Primitive to inputs actually applies the wrapped user-defined function `fn`
    to those inputs.
    The main difference between the `Morph` Primitive and :class:`mindspore.ops.Custom` is that the former is
    expanded and replaced by the user-defined `fn` before automatic differentiation, so there is no need to
    implement a backward function.

    .. note::
        - This primitive is only supported in GRAPH_MODE.
        - `fn` must satisfy the syntax constraints of the graph mode.
        - Users do not need to implement a custom backward function.
        - `vararg`, `kwarg`, `kwonlyargs` and free variables are not supported in user-defined function.

    Args:
        fn (Function): MindSpore's function, user-defined function.
        infer_shape (Function): MindSpore's function, user-defined infer_shape function.
        infer_dtype (Function): MindSpore's function, user-defined infer_dtype function.

    Inputs:
        The inputs of user-defined `fn`.

    Outputs:
        The outputs of user-defined `fn`.

    Raises:
        RuntimeError: if not used in GRAPH_MODE.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import context, nn, ops, Tensor, Parameter
        >>>
        >>> np_weight0 = np.array([1.0, 2.0, 3.0])
        >>> np_weight1 = np.array([4.0, 5.0, 6.0])
        >>> np_input_x = np.array([7.0, 8.0, 9.0])
        >>>
        >>> def infer_dtype(args):
        ...     return args
        >>>
        >>> def infer_shape(args):
        ...     return args
        >>>
        >>> def mul_by(*args):
        ...     def inner(x):
        ...         return args[0] * x
        ...     return inner
        >>>
        >>> NUMBER_100 = 100
        >>> class MorphNet(nn.Cell):
        ...     def __init__(self):
        ...         super(MorphNet, self).__init__()
        ...         self.weight0 = Parameter(Tensor(np_weight0, ms.float32), name="weight0")
        ...         self.weight1 = Parameter(Tensor(np_weight1, ms.float32), name="weight1")
        ...         self.mul_by_100 = ops.Morph(mul_by(NUMBER_100), infer_shape, infer_dtype)
        ...     def construct(self, x):
        ...         a = x * self.weight0
        ...         b = self.mul_by_100(a)
        ...         out = b * self.weight1
        ...         return out
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> input_x = Tensor(np_input_x, ms.float32)
        >>> net = MorphNet()
        >>> grad_op = ops.GradOperation(get_all=True, get_by_list=True)
        >>> grad_net = grad_op(net, net.trainable_params())
        >>> bwd_out = grad_net(input_x)
        >>> x_grad = bwd_out[0][0].asnumpy()
        >>> weight0_grad = bwd_out[1][0].asnumpy()
        >>> weight1_grad = bwd_out[1][1].asnumpy()
        >>> print("x_grad", x_grad)
        >>> print("weight0_grad", weight0_grad)
        >>> print("weight1_grad", weight1_grad)
        x_grad [ 400. 1000. 1800.]
        weight0_grad [2800. 4000. 5400.]
        weight1_grad [ 700. 1600. 2700.]
    """
    @prim_attr_register
    def __init__(self, fn, infer_shape, infer_dtype):
        # Mark all three side-effect categories, in this fixed order, then attach the wrapped function.
        for side_effect_attr in ('side_effect_backprop', 'side_effect_mem', 'side_effect_io'):
            self.add_prim_attr(side_effect_attr, True)
        self.add_prim_attr('__metamorphosis__', fn)
        # The user-provided inference callbacks are stored privately and forwarded by the
        # infer_shape / infer_dtype methods below.
        self._infer_shape = infer_shape
        self._infer_dtype = infer_dtype

    def infer_shape(self, *args):
        """Delegate shape inference to the user-provided `infer_shape` callback."""
        shape_fn = self._infer_shape
        return shape_fn(*args)

    def infer_dtype(self, *args):
        """Delegate dtype inference to the user-provided `infer_dtype` callback."""
        dtype_fn = self._infer_dtype
        return dtype_fn(*args)

    def __call__(self, *args):
        """Eager (PyNative) execution is not supported; Morph only works in GRAPH_MODE."""
        raise RuntimeError("Morph is only supported in GRAPH_MODE.")
|
|
565
|
+
|
|
566
|
+
|
|
469
567
|
class HookBackward(PrimitiveWithInfer):
|
|
470
568
|
"""
|
|
471
569
|
This operation is used as a tag to hook gradient in intermediate variables. Note that this function
|
|
@@ -527,8 +625,7 @@ class HookBackward(PrimitiveWithInfer):
|
|
|
527
625
|
def __init__(self, hook_fn, cell_id=""):
|
|
528
626
|
"""Initialize HookBackward."""
|
|
529
627
|
super(HookBackward, self).__init__(self.__class__.__name__)
|
|
530
|
-
|
|
531
|
-
return
|
|
628
|
+
check_hook_fn(hook_fn)
|
|
532
629
|
if cell_id != "":
|
|
533
630
|
logger.warning(f"The args 'cell_id' of HookBackward will be removed in a future version. If the value of "
|
|
534
631
|
f"'cell_id' is set, the hook function will not work.")
|
|
@@ -600,6 +697,9 @@ class Print(Primitive):
|
|
|
600
697
|
self.add_prim_attr("side_effect_io", True)
|
|
601
698
|
|
|
602
699
|
def __call__(self, *args):
|
|
700
|
+
# Add for jit context.
|
|
701
|
+
if jit_context() and jit_context().compiled:
|
|
702
|
+
return
|
|
603
703
|
for arg in args:
|
|
604
704
|
if isinstance(arg, Parameter):
|
|
605
705
|
print(Tensor_.__repr__(arg))
|
|
@@ -607,6 +707,9 @@ class Print(Primitive):
|
|
|
607
707
|
print(arg.__repr__())
|
|
608
708
|
else:
|
|
609
709
|
print(arg)
|
|
710
|
+
# Add for jit context.
|
|
711
|
+
if jit_context():
|
|
712
|
+
jit_context().run_op(self, None, *args)
|
|
610
713
|
|
|
611
714
|
|
|
612
715
|
class Assert(PrimitiveWithInfer):
|