mindspore-2.7.0-cp311-cp311-win_amd64.whl → mindspore-2.7.1-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -1157,6 +1157,7 @@ class FillV2(PrimitiveWithCheck):
                 init_func = Zero()
                 init_func.__enable_zero_dim__ = True
                 out = Tensor(shape=dims, dtype=x.dtype, init=init_func)
+                out.init_data()
                 return out
             return Tensor(np.full(dims, x.asnumpy()))
 
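The single added line makes this FillV2 fast path return a materialized tensor rather than a lazily-initialized one. A minimal standalone sketch of the lazy-init pattern involved (the shape and dtype here are arbitrary illustration):

    import mindspore as ms
    from mindspore import Tensor
    from mindspore.common.initializer import Zero

    # A Tensor built with `init=` holds no data until init_data() runs;
    # the new out.init_data() call forces materialization before returning.
    t = Tensor(shape=(2, 3), dtype=ms.float32, init=Zero())
    t.init_data()
    print(t.asnumpy())   # zero-filled 2x3 array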
@@ -1940,7 +1940,7 @@ class BatchISendIRecv(PrimitiveWithInfer):
 
 
 class AlltoAllV(PrimitiveWithInfer):
-    """
+    r"""
     AllToAllV which support uneven scatter and gather compared with AllToAll.
 
     Note:
@@ -2001,7 +2001,7 @@ class AlltoAllV(PrimitiveWithInfer):
         ...     send_tensor = Tensor([0, 1, 2.])
         ...     send_numel_list = [1, 2]
         ...     recv_numel_list = [1, 2]
-        >>> elif rank == 1:
+        ... elif rank == 1:
         ...     send_tensor = Tensor([3, 4, 5.])
         ...     send_numel_list = [2, 1]
         ...     recv_numel_list = [2, 1]
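The fix corrects doctest syntax: `...` continues the current compound statement, while `>>>` starts a new one, so an `elif` prefixed with `>>>` would be a SyntaxError when the example runs. A small parse check illustrating the rule:

    import doctest

    example = '''
    >>> rank = 0
    >>> if rank == 0:
    ...     branch = "rank 0"
    ... elif rank == 1:
    ...     branch = "rank 1"
    >>> branch
    'rank 0'
    '''
    # The if/elif block parses as one example only because elif uses `...`.
    print(len(doctest.DocTestParser().get_examples(example)))   # 3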
@@ -2013,6 +2013,10 @@ class AlltoAllV(PrimitiveWithInfer):
         rank 1:
         [1. 2. 5]
 
+    Tutorial Examples:
+        - `Distributed Set Communication Primitives - AlltoAllV
+          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#alltoallv>`_
+
     """
 
     @prim_attr_register
@@ -2024,6 +2028,94 @@ class AlltoAllV(PrimitiveWithInfer):
         self.add_prim_attr('block_size', self.block_size)
 
 
+class AlltoAllVC(PrimitiveWithInfer):
+    r"""
+    AllToAllVC passes in the sending and receiving parameters of all ranks through the input parameter
+    `send_count_matrix`. Compared to AllToAllV, AllToAllVC does not require the aggregation of all rank
+    sending and receiving parameters, thus offering superior performance.
+
+    Note:
+        Only one-dimensional input is supported; the input data must be flattened into a one-dimensional
+        array before using this interface.
+
+    Args:
+        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``, which
+            means ``"hccl_world_group"`` in Ascend.
+        block_size (int, optional): The basic units for splitting and gathering numel by `send_count_matrix`.
+            Default: ``1``.
+        transpose (bool, optional): Indicates whether the `send_count_matrix` needs to undergo a transpose
+            operation; this parameter is used in reverse calculation scenarios. Default: ``False``.
+
+    Inputs:
+        - **input_x** (Tensor) - flatten tensor to scatter. The shape of tensor is :math:`(x_1)`.
+        - **send_count_matrix** (Union[list[int], Tensor]) - The sending and receiving parameters of
+          all ranks, :math:`\text{send_count_matrix}[i*\text{rank_size}+j]` represents the amount of data sent by
+          rank i to rank j, and the basic unit is the number of bytes of Tensor's dtype. Among them, `rank_size`
+          indicates the size of the communication group.
+
+    Outputs:
+        Tensor. Flattened and concatenated tensor gathered from remote ranks.
+        If the gather result is empty, it will return a Tensor with shape `()`, and value has no actual meaning.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+            without any third-party or configuration file dependencies.
+
+            Please see the `msrun start up
+            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+            for more details.
+
+            This example should be run with 2 devices.
+
+        >>> from mindspore.ops import AlltoAllVC
+        >>> import mindspore.nn as nn
+        >>> from mindspore.communication import init, get_rank
+        >>> from mindspore import Tensor
+        >>>
+        >>> init()
+        >>> rank = get_rank()
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self.all_to_all_v_c = AlltoAllVC()
+        ...
+        ...     def construct(self, x, send_count_matrix):
+        ...         return self.all_to_all_v_c(x, send_count_matrix)
+        >>> send_count_matrix = Tensor([[0, 3], [3, 0]])
+        >>> send_tensor = Tensor([0, 1, 2.]) * rank
+        >>> net = Net()
+        >>> output = net(send_tensor, send_count_matrix)
+        >>> print(output)
+        rank 0:
+        [0. 1. 2]
+        rank 1:
+        [0. 0. 0]
+
+    Tutorial Examples:
+        - `Distributed Set Communication Primitives - AlltoAllVC
+          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#alltoallvc>`_
+
+    """
+
+    @prim_attr_register
+    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, block_size=1, transpose=False):
+        self.group = GlobalComm.WORLD_COMM_GROUP if group is None else _get_group(group)
+        self.rank_size = get_group_size(self.group)
+        self.add_prim_attr('rank_size', self.rank_size)
+        self.add_prim_attr('group', self.group)
+        self.rank_id = get_rank(_get_group(self.group))
+        self.add_prim_attr('rank_id', self.rank_id)
+        validator.check_value_type("block_size", block_size, [int], self.name)
+        self.add_prim_attr('block_size', self.block_size)
+        self.add_prim_attr('transpose', self.transpose)
+
+
 class AllGatherV(PrimitiveWithInfer):
     """
     Gathers uneven tensors from the specified communication group and returns the tensor which is all gathered.
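To make the `send_count_matrix` semantics concrete: row i lists what rank i sends to every rank, so rank j's receive counts are column j. A plain-Python sketch (no communication involved) using the matrix from the docstring example above:

    send_count_matrix = [[0, 3],
                         [3, 0]]   # row i = elements rank i sends to each rank
    rank_size = len(send_count_matrix)
    for j in range(rank_size):
        recv_counts = [send_count_matrix[i][j] for i in range(rank_size)]
        print(f"rank {j} receives {recv_counts}")
    # rank 0 receives [0, 3]: 3 elements from rank 1 ([0, 1, 2] * 1) -> [0. 1. 2]
    # rank 1 receives [3, 0]: 3 elements from rank 0 ([0, 1, 2] * 0) -> [0. 0. 0]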
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 import json
 import os
 import re
+import sys
 import ast
 import hashlib
 import stat
@@ -26,6 +27,7 @@ import inspect
 import importlib
 import platform
 import subprocess
+import shutil
 import numpy as np
 import mindspore as ms
 from mindspore._c_expression import Oplib, typing
@@ -37,6 +39,7 @@ from mindspore.ops import DataType
 from mindspore import log as logger
 from mindspore import ops
 from mindspore.communication.management import get_rank, GlobalComm
+from mindspore import _checkparam as validator
 from ._ms_kernel import determine_variable_usage
 from ._custom_grad import autodiff_bprop
 from ._pyfunc_registry import add_pyfunc
@@ -1185,6 +1188,54 @@ class Custom(ops.PrimitiveWithInfer):
         return ops.primitive._run_op(self, self.name, args)
 
 
+class _MultiSoProxy:
+    """
+    A thin wrapper that transparently multiplexes attribute access between a
+    pure-Python fallback module and an optional compiled shared-object (SO)
+    module, honoring MindSpore's current execution mode (GRAPH vs. PYNATIVE).
+    """
+
+    def __init__(self, func_module, so_module):
+        """
+        Args:
+            func_module (module or None): Python module to serve as the fallback implementation source.
+                May be ``None`` if no Python fallback is available.
+            so_module (module): Compiled shared-object module that provides
+                optimized kernels accessible only in ``PYNATIVE_MODE``.
+        """
+        self.func_module = func_module
+        self.so_module = so_module
+
+    def __getattr__(self, name: str):
+        """
+        Intercepts every attribute lookup and resolves it against the wrapped
+        modules according to the documented precedence rules.
+
+        Args:
+            name (str): Name of the custom operation being requested.
+
+        Returns:
+            Any: The requested callable or attribute from either ``func_module`` or ``so_module``.
+
+        Raises:
+            AttributeError: If the attribute is not found in any applicable module or
+                is incompatible with the current execution mode.
+        """
+        if self.func_module is not None and hasattr(self.func_module, name):
+            return getattr(self.func_module, name)
+        if context.get_context("mode") == ms.PYNATIVE_MODE:
+            if hasattr(self.so_module, name):
+                return getattr(self.so_module, name)
+            raise AttributeError(
+                f"Custom op '{name}' is neither in func_module nor in so_module."
+            )
+
+        raise AttributeError(
+            f"Custom op '{name}' does not support GRAPH mode. "
+            f"Please register it for GRAPH mode or switch to PYNATIVE mode."
+        )
+
+
 class CustomOpBuilder:
     r"""
     CustomOpBuilder is used to initialize and configure custom operators for MindSpore.
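The proxy's dispatch order (Python fallback first, compiled module only in PYNATIVE mode) can be exercised without MindSpore. A self-contained sketch of the same `__getattr__` pattern, with `SimpleNamespace` objects standing in for the two modules and the mode check omitted for brevity:

    from types import SimpleNamespace

    class Proxy:
        def __init__(self, func_module, so_module):
            self.func_module = func_module
            self.so_module = so_module

        def __getattr__(self, name):
            # Called only when normal lookup fails: try fallback, then .so.
            for mod in (self.func_module, self.so_module):
                if mod is not None and hasattr(mod, name):
                    return getattr(mod, name)
            raise AttributeError(name)

    py_impl = SimpleNamespace(add=lambda a, b: a + b)   # stands in for func_module
    so_impl = SimpleNamespace(mul=lambda a, b: a * b)   # stands in for the .so
    proxy = Proxy(py_impl, so_impl)
    print(proxy.add(2, 3), proxy.mul(2, 3))             # 5 6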
@@ -1200,10 +1251,11 @@ class CustomOpBuilder:
 
     Args:
         name (str): The unique name of the custom operator module, used to identify the operator.
-        sources (Union[str, list[str]]): The source file(s) of the custom operator. It can be a single file path or
-            a list of file paths.
+        sources (Union[list[str], tuple[str], str]): The source file(s) of the custom operator. It can be a single
+            file path or a list of file paths.
         backend (str, optional): The target backend for the operator, such as "CPU" or "Ascend". Default: ``None``.
-        include_paths (list[str], optional): Additionally included paths needed during compilation. Default: ``None``.
+        include_paths (Union[list[str], tuple[str], str], optional): Additionally included paths needed during
+            compilation. Default: ``None``.
         cflags (str, optional): Extra C++ compiler flags to be used during compilation. Default: ``None``.
         ldflags (str, optional): Extra linker flags to be used during linking. Default: ``None``.
         kwargs (dict, optional): Additional keyword arguments for future extensions or specific custom requirements.
@@ -1217,6 +1269,17 @@ class CustomOpBuilder:
         - enable_atb (bool, optional): Whether to call ATB (Ascend Transformer Boost) operator. If set to ``True``,
           the `backend` must be ``Ascend`` or left empty. Default: ``False``.
 
+        - enable_asdsip (bool, optional): Whether to call ASDSIP (Ascend SiP Boost) operator. If set to ``True``,
+          the `backend` must be ``Ascend`` or left empty. Default: ``False``.
+
+        - op_def (Union[list[str], tuple[str], str], optional): Path(s) to the operator definition
+          file(s) (YAML format). When using custom operators in graph mode, this parameter is mandatory.
+          It can be a single file path string or a list of file path strings. Default: ``None``.
+
+        - op_doc (Union[list[str], tuple[str], str], optional): Path(s) to the operator documentation
+          file(s) (YAML format). This parameter is optional and used to provide additional documentation
+          for the operator. It can be a single file path string or a list of file path strings. Default: ``None``.
+
     .. note::
         - If the `backend` argument is provided, additional default flags will be automatically added to
           the compilation and linking steps to support the operator's target backend. The default options
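A hedged usage sketch of the new kwargs (all file names below are hypothetical placeholders; an Ascend toolchain and the matching environment variables are assumed):

    from mindspore.ops import CustomOpBuilder

    builder = CustomOpBuilder(
        name="my_custom_ops",      # hypothetical module name
        sources=["my_op.cpp"],     # str, list, or tuple of str
        backend="Ascend",
        op_def="my_op.yaml",       # new in 2.7.1: required for graph-mode custom ops
        op_doc="my_op_doc.yaml",   # new in 2.7.1: optional documentation YAML
    )
    mod = builder.load()           # returns the _MultiSoProxy wrapper shown above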
@@ -1239,21 +1302,20 @@ class CustomOpBuilder:
     _loaded_ops = {}
 
     def __init__(self, name, sources, backend=None, include_paths=None, cflags=None, ldflags=None, **kwargs):
-        self.name = name
-        self.source = sources
-        self.backend = backend
-        self.include_paths = include_paths
-        self.cflags = cflags
-        self.ldflags = ldflags
-        self.build_dir = kwargs.get("build_dir")
-        self.enable_atb = kwargs.get("enable_atb", False)
-        self.debug_mode = kwargs.get("debug_mode", False)
+        self._check_and_get_args(name, sources, backend, include_paths, cflags, ldflags, **kwargs)
+
         self._ms_path = os.path.dirname(os.path.abspath(ms.__file__))
+        self.auto_generate = self.name + "_auto_generate"
         if self.enable_atb:
             if backend is not None and backend != "Ascend":
                 raise ValueError("For 'CustomOpBuilder', when 'enable_atb' is set to True, the 'backend' must be "
                                  f"'Ascend' (or left implicit), but got '{backend}'")
             self.backend = "Ascend"
+        if self.enable_asdsip:
+            if backend is not None and backend != "Ascend":
+                raise ValueError("For 'CustomOpBuilder', when 'enable_asdsip' is set to True, the 'backend' must be "
+                                 f"'Ascend' (or left implicit), but got '{backend}'")
+            self.backend = "Ascend"
         if self.backend == "Ascend":
             ascend_opp_path = os.getenv("ASCEND_OPP_PATH")
             if not ascend_opp_path:
@@ -1265,6 +1327,115 @@ class CustomOpBuilder:
             if not self.atb_home_path:
                 raise ValueError("Environment variable 'ATB_HOME_PATH' must be set when 'enable_atb' is True.")
 
+    def _check_and_get_args(self, name, sources, backend=None, include_paths=None,
+                            cflags=None, ldflags=None, **kwargs):
+        """
+        Validate and normalize all arguments to meet custom-op build requirements.
+        """
+
+        def _check_str_or_list_str(key, val):
+            if val is None:
+                return val
+            if isinstance(val, str):
+                val = [val]
+            val = validator.check_value_type(key, val, [list, tuple])
+            val = list(val)
+            validator.check_element_type_of_iterable(key, val, [str])
+            return val
+
+        self.name = validator.check_value_type("name", name, [str])
+        self.source = _check_str_or_list_str("sources", sources)
+        self.backend = validator.check_value_type("backend", backend, [str, type(None)])
+        if self.backend is not None and self.backend not in {"CPU", "Ascend"}:
+            raise ValueError(
+                f"For 'backend', only 'CPU' or 'Ascend' are allowed, but got '{self.backend}'.")
+
+        self.include_paths = _check_str_or_list_str("include_paths", include_paths)
+
+        self.cflags = validator.check_value_type("cflags", cflags, [str, type(None)])
+        self.ldflags = validator.check_value_type("ldflags", ldflags, [str, type(None)])
+
+        self.build_dir = validator.check_value_type("build_dir",
+                                                    kwargs.get("build_dir"),
+                                                    [str, type(None)])
+
+        self.debug_mode = validator.check_bool(kwargs.get("debug_mode", False), "debug_mode")
+        self.enable_asdsip = validator.check_bool(kwargs.get("enable_asdsip", False), "enable_asdsip")
+        self.yaml = _check_str_or_list_str("op_def", kwargs.get("op_def"))
+        self.doc = _check_str_or_list_str("op_doc", kwargs.get("op_doc"))
+
+        self.enable_atb = validator.check_bool(kwargs.get("enable_atb", False))
+
+    def _generate_custom_op_def(self, module: str, input_path: str, doc_path: str, output_path: str) -> None:
+        """Call gen_custom_ops.py to generate custom operator definition"""
+        file_path = os.path.join(self._ms_path, "ops_generate/gen_custom_ops.py")
+        cmd = [
+            sys.executable,
+            file_path,
+            "-i", input_path,
+            "-o", output_path,
+            "-m", module,
+            "-d", doc_path
+        ]
+
+        try:
+            subprocess.run(
+                cmd,
+                check=True,
+                text=True,
+                capture_output=True
+            )
+        except subprocess.CalledProcessError as exc:
+            raise RuntimeError(
+                f"gen_custom_op.py failed with exit code {exc.returncode}.\n"
+                f"stdout: {exc.stdout}\n"
+                f"stderr: {exc.stderr}"
+            ) from None
+
+    def _get_op_def(self):
+        """
+        Generate C++ operator-definition source files from one or more YAML specification files.
+        """
+        if self.yaml is None:
+            return []
+
+        if self.doc is None:
+            logger.info("Missing required 'doc': no YAML document was provided.")
+
+        build_path = self._get_build_directory()
+        yaml_path = os.path.join(build_path, "yaml")
+        op_def_path = os.path.join(build_path, self.auto_generate)
+        if os.path.exists(op_def_path):
+            shutil.rmtree(op_def_path)
+        os.makedirs(op_def_path, exist_ok=True)
+
+        def copy_files(yaml_files, dest_path):
+            if os.path.exists(dest_path):
+                shutil.rmtree(dest_path)
+            os.makedirs(dest_path, exist_ok=True)
+            for file_path in yaml_files:
+                if not os.path.isfile(file_path):
+                    raise FileNotFoundError(f"File not found: {file_path}")
+
+                filename = os.path.basename(file_path)
+                file_ext = os.path.splitext(filename)[1].lower()
+                if file_ext not in ('.yaml', '.yml'):
+                    raise ValueError(f"Invalid file extension: {file_ext} for {filename}")
+
+                _dest_path = os.path.join(dest_path, filename)
+                shutil.copy2(file_path, _dest_path)
+
+        yaml_files = [self.yaml] if isinstance(self.yaml, str) else self.yaml
+        copy_files(yaml_files, yaml_path)
+        doc_path = ""
+        if self.doc is not None:
+            doc_path = os.path.join(build_path, "doc")
+            doc_files = [self.doc] if isinstance(self.doc, str) else self.doc
+            copy_files(doc_files, doc_path)
+
+        self._generate_custom_op_def(self.name, yaml_path, doc_path, op_def_path)
+        return [os.path.join(op_def_path, "gen_custom_ops_def.cc")]
+
     def get_sources(self):
         """
         Get the source files for the custom operator.
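The normalization rule `_check_str_or_list_str` applies is easy to state in isolation: `None` passes through, a single string becomes a one-element list, and a list or tuple of strings becomes a list. A standalone sketch of the same rule without the MindSpore validator:

    def normalize(key, val):
        if val is None:
            return None
        if isinstance(val, str):
            return [val]
        if isinstance(val, (list, tuple)) and all(isinstance(v, str) for v in val):
            return list(val)
        raise TypeError(f"'{key}' must be a str or a list/tuple of str")

    print(normalize("sources", "a.cpp"))             # ['a.cpp']
    print(normalize("sources", ("a.cpp", "b.cpp")))  # ['a.cpp', 'b.cpp']
    print(normalize("op_def", None))                 # None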
@@ -1272,7 +1443,8 @@ class CustomOpBuilder:
         Returns:
             str or list[str], The source file(s) for the operator.
         """
-        return self.source
+        self.source = [self.source] if isinstance(self.source, str) else self.source
+        return self.source + self._get_op_def()
 
     def get_include_paths(self):
         """
@@ -1299,6 +1471,7 @@ class CustomOpBuilder:
         """include paths for inner module interface."""
         ms_inner_path = os.path.join(self._ms_path, "include", "mindspore")
         include_list = []
+        include_list.append(os.path.join(ms_inner_path, "include"))
         include_list.append(os.path.join(ms_inner_path, "core", "include"))
         include_list.append(os.path.join(ms_inner_path, "core", "mindrt", "include"))
         include_list.append(os.path.join(ms_inner_path, "core", "mindrt"))
1320
1493
  flags += ['-std=c++17', '-fstack-protector-all', '-fPIC', '-pie']
1321
1494
  if self.debug_mode:
1322
1495
  flags.append('-g')
1496
+ else:
1497
+ flags.append('-O2')
1323
1498
  if self.backend == "Ascend":
1324
1499
  flags.append('-DCUSTOM_ASCEND_OP')
1325
1500
  if self.enable_atb:
1326
1501
  flags.append('-DCUSTOM_ENABLE_ATB')
1502
+ if self.enable_asdsip:
1503
+ flags.append('-DCUSTOM_ENABLE_ASDSIP')
1327
1504
  if self.cflags is not None:
1328
1505
  flags.append(self.cflags)
1329
1506
  return flags
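The compile-flag change is small but behavior-relevant: release builds of custom ops now get `-O2` where previously no optimization level was passed at all. A standalone reconstruction of the flag assembly, for reference:

    def build_cxx_flags(debug_mode, backend, enable_atb=False, enable_asdsip=False):
        flags = ['-std=c++17', '-fstack-protector-all', '-fPIC', '-pie']
        flags.append('-g' if debug_mode else '-O2')   # new else branch
        if backend == "Ascend":
            flags.append('-DCUSTOM_ASCEND_OP')
            if enable_atb:
                flags.append('-DCUSTOM_ENABLE_ATB')
            if enable_asdsip:
                flags.append('-DCUSTOM_ENABLE_ASDSIP')
        return flags

    print(build_cxx_flags(debug_mode=False, backend="Ascend", enable_asdsip=True))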
@@ -1344,18 +1521,23 @@ class CustomOpBuilder:
             '-lmindspore_core',
             '-lmindspore_ms_backend',
             '-lmindspore_pynative',
-            '-lmindspore_extension'
+            '-lmindspore_pyboost'
         ]
         if self.backend == "Ascend":
-            flags.append(f"-L{os.path.abspath(os.path.join(self._ms_path, 'lib', 'plugin'))}")
             flags.append(f"-L{os.path.abspath(os.path.join(self.ascend_cann_path, 'lib64'))}")
             flags.append('-lascendcl')
+            plugin_path = os.path.abspath(os.path.join(self._ms_path, 'lib', 'plugin'))
+            flags.append(f"-L{plugin_path}")
+            flags.append(f"-L{os.path.join(plugin_path, 'ascend')}")
             flags.append('-l:libmindspore_ascend.so.2')
+            flags.append('-lmindspore_extension_ascend_aclnn')
             if self.enable_atb:
-                flags.append(f"-L{os.path.abspath(os.path.join(self._ms_path, 'lib', 'plugin', 'ascend'))}")
                 flags.append('-lmindspore_extension_ascend_atb')
                 flags.append(f"-L{os.path.abspath(os.path.join(self.atb_home_path, 'lib'))}")
                 flags.append('-latb')
+            if self.enable_asdsip:
+                flags.append(f"-L{os.path.abspath(os.path.join(self._ms_path, 'lib', 'plugin', 'ascend'))}")
+                flags.append('-lmindspore_extension_ascend_asdsip')
         if self.ldflags is not None:
             flags.append(self.ldflags)
         return flags
@@ -1386,15 +1568,42 @@ class CustomOpBuilder:
         """
         if self.name in CustomOpBuilder._loaded_ops:
             return CustomOpBuilder._loaded_ops[self.name]
+
         module_path = self.build()
-        mod = self._import_module(module_path)
+        so_module = CustomOpBuilder._import_module(module_path)
+        func_module = None
+        if self.yaml is not None:
+            module_path = os.path.join(self.build_dir, self.auto_generate, "gen_ops_def.py")
+            sys.path.append(os.path.join(self.build_dir, self.auto_generate))
+            sys.path.append(os.path.join(self.build_dir))
+            func_module = self._import_module(module_path, True)
+        mod = _MultiSoProxy(func_module, so_module)
         CustomOpBuilder._loaded_ops[self.name] = mod
         return mod
 
-    def _import_module(self, module_path):
+    @staticmethod
+    def _import_module(module_path, is_yaml_build=False):
         """Import module from library."""
-        spec = importlib.util.spec_from_file_location(self.name, module_path)
+        module_path = os.path.abspath(module_path)
+        module_dir = os.path.dirname(module_path)
+        module_name = os.path.splitext(os.path.basename(module_path))[0]
+
+        if is_yaml_build:
+            package_name = os.path.basename(module_dir)
+            if module_dir not in sys.path:
+                sys.path.append(module_dir)
+
+            if package_name not in sys.modules:
+                pkg_spec = importlib.machinery.ModuleSpec(package_name, None, is_package=True)
+                pkg = importlib.util.module_from_spec(pkg_spec)
+                pkg.__path__ = [module_dir]
+                sys.modules[package_name] = pkg
+
+            module_name = f"{package_name}.{module_name}"
+
+        spec = importlib.util.spec_from_file_location(module_name, module_path)
         module = importlib.util.module_from_spec(spec)
+        sys.modules[module_name] = module
         spec.loader.exec_module(module)
         return module
 
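The reworked `_import_module` follows the standard importlib file-path recipe, plus `sys.modules` registration so that imports inside the generated package resolve. A minimal sketch of just the file-path import (the package branch is omitted; `module_path` is any path you supply):

    import importlib.util
    import os
    import sys

    def import_from_path(module_path):
        """Load a Python module from an explicit file path."""
        module_name = os.path.splitext(os.path.basename(module_path))[0]
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module   # register before exec, as the diff does
        spec.loader.exec_module(module)
        return module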
@@ -13,14 +13,11 @@
 # limitations under the License.
 # ============================================================================
 """debug_ops"""
-import stat
-from pathlib import Path
-
-import numpy as np
+import inspect
 from mindspore import log as logger
 from mindspore._c_expression import security, HookType
 from mindspore._c_expression import TensorPy as Tensor_
-from mindspore._c_expression import _tensordump_process_file
+from mindspore._c_expression import _tensordump_exec
 from mindspore import _checkparam as validator
 from mindspore.common import dtype as mstype
 from mindspore.common.parameter import Parameter
@@ -314,26 +311,12 @@ class TensorDump(Primitive):
         self.add_prim_attr("side_effect_io", True)
         self.add_prim_attr("channel_name", "ms_tensor_dump")
 
-    def _save_file(self, file, data):
-        file = Path(file)
-        if file.exists():
-            file.chmod(stat.S_IWUSR)
-        np.save(file, data)
-        file.chmod(stat.S_IRUSR)
-
     def __call__(self, file, input_x):
         validator.check_value_type('file', file, [str], self.__class__.__name__)
         if not file:
             raise ValueError("For 'TensorDump', the input argument[file] cannot be an empty string.")
         validator.check_value_type('input_x', input_x, [Tensor], self.__class__.__name__)
-
-        dtype = input_x.dtype
-        file = _tensordump_process_file(file, str(dtype))
-        if not file:
-            return
-        if dtype == mstype.bfloat16:
-            input_x = P.Cast()(input_x, mstype.float32)
-        self._save_file(file, input_x.asnumpy())
+        _tensordump_exec(file, input_x)
 
 
 class HistogramSummary(Primitive):
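With the dump logic moved into C++ (`_tensordump_exec`), the Python side of `TensorDump.__call__` is now just validation plus a single call; the bfloat16 cast and file-permission handling happen below the Python layer. User-facing usage is unchanged; a minimal PyNative sketch (the output file name is illustrative):

    import numpy as np
    from mindspore import Tensor, ops

    x = Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32))
    dump = ops.TensorDump()
    dump("x_dump", x)   # persists x under the given prefix as a .npy file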
@@ -529,14 +512,15 @@ class Morph(PrimitiveWithInfer):
 
     .. note::
         - This primitive is only supported in GRAPH_MODE.
-        - `fn` must satisfy the syntax constraints of the graph mode.
-        - Users do not need to implement a custom backward function.
+        - A user-defined bprop (by argument: `bprop_fn`) is allowed for `Morph`.
+        - `fn` and `bprop_fn` must satisfy the syntax constraints of the graph mode.
         - `vararg`, `kwarg`, `kwonlyargs` and free variables are not supported in user-defined function.
 
     Args:
-        fn (Function): Mindspore's function, user-defined function.
-        infer_shape (Function): Mindspore's function, user-defined infer_shape function.
-        infer_dtype (Function): Mindspore's function, user-defined infer_dtype function.
+        fn (Function): MindSpore's function, user-defined function.
+        infer_shape (Function): MindSpore's function, user-defined infer_shape function.
+        infer_dtype (Function): MindSpore's function, user-defined infer_dtype function.
+        bprop_fn (Function, optional): MindSpore's function, user-defined bprop function, default: ``None``.
 
     Inputs:
         The inputs of user-defined `fn`.
@@ -590,21 +574,35 @@ class Morph(PrimitiveWithInfer):
     >>> weight0_grad = bwd_out[1][0].asnumpy()
     >>> weight1_grad = bwd_out[1][1].asnumpy()
     >>> print("x_grad", x_grad)
-    >>> print("weight0_grad", weight0_grad)
-    >>> print("weight1_grad", weight1_grad)
     x_grad [ 400. 1000. 1800.]
+    >>> print("weight0_grad", weight0_grad)
     weight0_grad [2800. 4000. 5400.]
+    >>> print("weight1_grad", weight1_grad)
     weight1_grad [ 700. 1600. 2700.]
     """
     @prim_attr_register
-    def __init__(self, fn, infer_shape, infer_dtype):
+    def __init__(self, fn, infer_shape, infer_dtype, bprop_fn=None):
         self.add_prim_attr('side_effect_backprop', True)
         self.add_prim_attr('side_effect_mem', True)
         self.add_prim_attr('side_effect_io', True)
-        self.add_prim_attr('__metamorphosis__', fn)
         self._infer_shape = infer_shape
         self._infer_dtype = infer_dtype
 
+        self.add_prim_attr('__metamorphosis__', True)
+        self.__morph_fn__ = fn
+        self.__morph_bprop_fn__ = None
+        if bprop_fn:
+            self._check_fn_supported(fn)
+            self.__morph_bprop_fn__ = bprop_fn
+
+    def _check_fn_supported(self, fn):
+        fn_sig = inspect.signature(fn)
+        for param in fn_sig.parameters.values():
+            if not (param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD and param.default is inspect.Parameter.empty):
+                raise ValueError(f"When use `bprop` in Morph, Morph `fn` only support positional or keyword parameters "
+                                 f"with default value is empty, but got param '{param.name}' "
+                                 f"of kind '{param.kind.name}' with default value '{param.default}'.")
+
     def infer_shape(self, *args):
         return self._infer_shape(*args)
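A hedged sketch of the new `bprop_fn` hook. The backward signature below follows MindSpore's usual bprop convention (`*inputs, out, dout` returning one gradient per input); treat it as illustrative of the parameter constraint enforced by `_check_fn_supported` (positional-or-keyword parameters with no defaults), not as the exact released contract:

    from mindspore import ops

    def fn(x, y):                      # plain positional parameters only,
        return x * y                   # as _check_fn_supported requires

    def infer_shape(x_shape, y_shape):
        return x_shape

    def infer_dtype(x_dtype, y_dtype):
        return x_dtype

    def bprop_fn(x, y, out, dout):     # user-defined backward (assumed convention)
        return dout * y, dout * x

    morph = ops.Morph(fn, infer_shape, infer_dtype, bprop_fn=bprop_fn)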