mindspore 2.7.0rc1__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (370) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -20,14 +20,14 @@ import os
20
20
 
21
21
  import common.gen_constants as K
22
22
  import common.gen_utils as gen_utils
23
- import common.template as template
24
- from common.base_generator import BaseGenerator
23
+ import common.template_utils as template
25
24
  from common.op_proto import OpProto
26
- from common.template import Template
25
+ from common.template_utils import Template
27
26
  from pyboost import pyboost_utils
27
+ from op_def_py.base_op_prim_py_generator import BaseOpPrimPyGenerator, _generate_arg_handler, generate_py_op_deprecated
28
28
 
29
29
 
30
- class OpPrimPyGenerator(BaseGenerator):
30
+ class OpPrimPyGenerator(BaseOpPrimPyGenerator):
31
31
  """
32
32
  Generates Python code for primitive operators based on provided specifications.
33
33
  """
@@ -87,7 +87,7 @@ class OpPrimPyGenerator(BaseGenerator):
87
87
 
88
88
  pyboost_import_header = self.generate_pyboost_import_header(op_protos)
89
89
  res_str = template.PY_LICENSE_STR + \
90
- template.OPS_PY_PRIM_HEADER + pyboost_import_header + gen_py
90
+ template.OPS_PY_PRIM_HEADER + pyboost_import_header + gen_py
91
91
 
92
92
  save_path = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
93
93
  file_name = f"{file_pre}_ops_prim.py"
@@ -111,113 +111,6 @@ class OpPrimPyGenerator(BaseGenerator):
111
111
  pyboost_import_header += header
112
112
  return pyboost_import_header
113
113
 
114
- def _process_args(self, op_proto: OpProto):
115
- """
116
- Processes operator arguments to categorize them for code generation.
117
-
118
- Args:
119
- op_proto (OpProto): The operator prototype.
120
-
121
- Returns:
122
- tuple: A tuple containing processed arguments.
123
- """
124
- inputs_name = []
125
- args_name = []
126
- args_assign = []
127
- inputs_default = {}
128
- init_args_with_default = []
129
- args_handlers = {}
130
-
131
- for arg in op_proto.op_args:
132
- # step1: get args infos:
133
- if arg.is_prim_init:
134
- # step1.1: get args name:
135
- args_name.append(arg.arg_name)
136
- # step1.2: get args assign with default value:
137
- if arg.default is not None:
138
- init_args_with_default.append(f"""{arg.arg_name}={arg.default}""")
139
- else:
140
- init_args_with_default.append(f"""{arg.arg_name}""")
141
-
142
- # step1.3: get args set prim arg expression:
143
- assign_str = self._get_assign_str_by_type_it(op_proto.op_class.name, arg)
144
- if arg.arg_handler:
145
- assign_str = (
146
- f' self._set_prim_arg_with_handler('
147
- f'"{arg.arg_name}", {assign_str}, {arg.arg_handler})'
148
- )
149
- else:
150
- assign_str = f""" self._set_prim_arg("{arg.arg_name}", {assign_str})"""
151
- args_assign.append(assign_str)
152
- # step2: get inputs infos:
153
- else:
154
- # step2.1: get inputs name:
155
- inputs_name.append(arg.arg_name)
156
-
157
- # step2.2: get default value of inputs:
158
- if arg.default is not None:
159
- inputs_default[arg.arg_name] = arg.default
160
-
161
- # step2.3: get args_handler functions for inputs
162
- if arg.arg_handler:
163
- args_handlers[arg.arg_name] = arg.arg_handler
164
-
165
- return inputs_name, inputs_default, args_name, args_assign, init_args_with_default, args_handlers
166
-
167
- def _get_assign_str_by_type_it(self, class_name, arg):
168
- """
169
- Generates assignment string with type casting.
170
-
171
- Args:
172
- class_name (str): The name of the class.
173
- arg (OpArg): The operator argument.
174
-
175
- Returns:
176
- str: A string representing the assignment.
177
- """
178
- assign_str = ""
179
- type_cast = arg.type_cast
180
- if type_cast:
181
- assign_str += f"type_it('{class_name}', '{arg.arg_name}', {arg.arg_name}, "
182
- if len(type_cast) == 1:
183
- assign_str += gen_utils.get_type_str(type_cast[0]) + ', '
184
- else:
185
- assign_str += '(' + ', '.join(gen_utils.get_type_str(ct) for ct in type_cast) + '), '
186
- assign_str += gen_utils.get_type_str(arg.arg_dtype) + ')'
187
- else:
188
- assign_str = arg.arg_name
189
- return assign_str
190
-
191
- def _generate_class_desc(self, op_proto: OpProto, input_args, init_args, doc_dic):
192
- """
193
- Generates a class description based on the operator prototype.
194
-
195
- Args:
196
- op_proto (OpProto): The operator prototype.
197
- input_args (list): List of input argument names.
198
- init_args (list): List of initialization argument names.
199
- doc_dic (dict): Documentation dictionary.
200
-
201
- Returns:
202
- str: A string containing the class description.
203
- """
204
- if op_proto.op_function and op_proto.op_function.disable:
205
- # if function disabled, function name is equal to operator_name
206
- return gen_utils.get_op_description(op_proto.op_name, doc_dic)
207
-
208
- # If function is a released API, refer to the function doc.
209
- init_args_str = ", ".join(init_args)
210
- input_args_str = ", ".join(input_args)
211
- args_str = ", ".join(input_args + init_args)
212
-
213
- description_template = Template(template.PRIMITIVE_CLASS_DESC)
214
- description_str = description_template.replace(class_name=op_proto.op_class.name,
215
- init_args_str=init_args_str,
216
- input_args_str=input_args_str,
217
- func_name=op_proto.op_function.name,
218
- args_str=args_str)
219
- return description_str
220
-
221
114
  def _generate_init_code(self, args_assign, init_args_with_default, op_proto: OpProto):
222
115
  """
223
116
  Generates the __init__ method code for the operator primitive class.
@@ -242,50 +135,6 @@ class OpPrimPyGenerator(BaseGenerator):
242
135
  init_code_str += f"\n"
243
136
  return init_code_str
244
137
 
245
- def _get_init_code(self, init_code, op_proto: OpProto):
246
- """
247
- Generates additional initialization code for the operator primitive class.
248
-
249
- Args:
250
- init_code (str): Existing initialization code.
251
- op_proto (OpProto): The operator prototype.
252
-
253
- Returns:
254
- str: A string containing additional initialization code.
255
- """
256
- labels_dic = op_proto.op_labels
257
- if labels_dic:
258
- if init_code:
259
- init_code += "\n"
260
- init_code += "\n".join([f""" self.add_prim_attr("{k}", {v})""" for k, v in labels_dic.items()])
261
-
262
- return init_code if init_code else f""" pass"""
263
-
264
- def _generate_call_code(self, args_handlers, init_args, inputs_args, inputs_default, op_proto: OpProto):
265
- """
266
- Generates the __call__ method code for the operator primitive class.
267
-
268
- Args:
269
- args_handlers (dict): Dictionary of argument handlers.
270
- init_args (list): List of initialization argument names.
271
- inputs_args (list): List of input argument names.
272
- inputs_default (dict): Dictionary of default input values.
273
- op_proto (OpProto): The operator prototype.
274
-
275
- Returns:
276
- str: A string containing the __call__ method code.
277
- """
278
- call_code_str = ""
279
- call_args = []
280
- for name in inputs_args:
281
- call_args.append(f"{name}={inputs_default[name]}" if name in inputs_default else name)
282
- call_method_args_str = ", ".join(call_args)
283
- call_method_body_str = self._get_call_method_body_str(args_handlers, init_args, inputs_args, inputs_default,
284
- op_proto)
285
- call_code_str += f""" def __call__(self, {call_method_args_str}):"""
286
- call_code_str += f"""{call_method_body_str}"""
287
- return call_code_str
288
-
289
138
  def _get_call_method_body_str(self, args_handlers, init_args, inputs_args, inputs_default, op_proto: OpProto):
290
139
  """
291
140
  Generates the body of the __call__ method.
@@ -334,159 +183,3 @@ class OpPrimPyGenerator(BaseGenerator):
334
183
  call_method_body_str += f"""
335
184
  return super().__call__({call_args_list_str})\n"""
336
185
  return call_method_body_str
337
-
338
- def _generate_py_op_signature(self, op_proto: OpProto, args_name, args_default):
339
- """
340
- Generates the __mindspore_signature__ for the operator.
341
-
342
- Args:
343
- op_proto (OpProto): The operator prototype.
344
- args_name (list): List of argument names.
345
- args_default (dict): Dictionary of default argument values.
346
-
347
- Returns:
348
- str: A string containing the __mindspore_signature__ code.
349
- """
350
- op_name = op_proto.op_name
351
- args_signature = op_proto.op_args_signature
352
-
353
- if args_signature is None and not args_default:
354
- return ''
355
-
356
- signature_code = f"""\n __mindspore_signature__ = """
357
-
358
- # Init rw.
359
- read_list, ref_list, write_list = gen_utils.init_args_signature_rw(args_signature)
360
- _check_signature_arg_valid(op_name, write_list, args_name)
361
- _check_signature_arg_valid(op_name, read_list, args_name)
362
- _check_signature_arg_valid(op_name, ref_list, args_name)
363
-
364
- # Init dtype group.
365
- same_dtype_groups, dtype_count = gen_utils.get_same_dtype_groups(args_signature, args_name)
366
- _check_signature_arg_valid(op_name, list(same_dtype_groups.keys()), args_name)
367
-
368
- # Only one dtype_group is set.
369
- if dtype_count == 1 and not any([write_list, read_list, ref_list, args_default]):
370
- signature_code += '('
371
- for _ in range(len(args_name) - 1):
372
- signature_code += 'sig.sig_dtype.T, '
373
- signature_code += 'sig.sig_dtype.T)\n'
374
- return signature_code
375
-
376
- # Set sig.make_sig.
377
- signature_code += f""" (\n"""
378
- for arg_name in args_name:
379
- signature_code += f""" sig.make_sig('{arg_name}'"""
380
- signature_code += signature_get_rw_label(arg_name, write_list, read_list, ref_list)
381
- if arg_name in same_dtype_groups:
382
- signature_code += f""", """ + signature_get_dtype_label(same_dtype_groups[arg_name])
383
- if arg_name in args_default:
384
- signature_code += f""", default=""" + str(args_default[arg_name])
385
- signature_code += f"""),\n"""
386
- signature_code += f""" )\n"""
387
- return signature_code
388
-
389
-
390
- def _check_signature_arg_valid(op_name, sig_arg_names, args_names):
391
- """
392
- Validates that all signature arguments are present in the list of argument names.
393
-
394
- Args:
395
- op_name (str): The name of the operator.
396
- sig_arg_names (list): List of signature argument names.
397
- args_names (list): List of actual argument names.
398
-
399
- Raises:
400
- ValueError: If a signature argument is not found in the list of argument names.
401
- """
402
- for sig_arg_name in sig_arg_names:
403
- if sig_arg_name not in args_names:
404
- raise ValueError(f"Op {op_name} has no input arg named '{sig_arg_name}'!")
405
-
406
-
407
- def signature_get_dtype_label(index):
408
- """
409
- Generates the label for the data type in the signature.
410
-
411
- Args:
412
- index (int): The index of the data type.
413
-
414
- Returns:
415
- str: The label string for the data type.
416
- """
417
- dtype_index = ''
418
- if index > 0:
419
- dtype_index = f"""{index}"""
420
- return f"""dtype=sig.sig_dtype.T{dtype_index}"""
421
-
422
-
423
- def signature_get_rw_label(arg_name, write_list, read_list, ref_list):
424
- """
425
- Determines the read-write label for an argument in the signature.
426
-
427
- Args:
428
- arg_name (str): The name of the argument.
429
- write_list (list): List of arguments that are writable.
430
- read_list (list): List of arguments that are readable.
431
- ref_list (list): List of arguments that are references.
432
-
433
- Returns:
434
- str: The read-write label for the argument.
435
- """
436
- for rw_arg_name in write_list:
437
- if rw_arg_name == arg_name:
438
- return ', sig.sig_rw.RW_WRITE'
439
- for read_arg_name in read_list:
440
- if read_arg_name == arg_name:
441
- return ', sig.sig_rw.RW_READ'
442
- for ref_arg_name in ref_list:
443
- if ref_arg_name == arg_name:
444
- return ', sig.sig_rw.RW_REF'
445
- return ''
446
-
447
-
448
- def generate_py_op_deprecated(deprecated):
449
- """
450
- Generates the deprecated decorator for an operator.
451
-
452
- Args:
453
- deprecated (dict): The deprecation information.
454
-
455
- Returns:
456
- str: A string containing the deprecated decorator.
457
- """
458
- if deprecated is None:
459
- return ''
460
- version = deprecated.get("version")
461
- if version is None:
462
- raise ValueError("The version of deprecated can't be None.")
463
- substitute = deprecated.get("substitute")
464
- if substitute is None:
465
- raise ValueError("The substitute of deprecated can't be None.")
466
- use_substitute = deprecated.get("use_substitute")
467
- if use_substitute is None:
468
- raise ValueError("The use_substitute of deprecated can't be None.")
469
- if use_substitute is not True and use_substitute is not False:
470
- raise ValueError(f"The use_substitute must be True or False, but got {use_substitute}")
471
-
472
- deprecated = f""" @deprecated("{version}", "{substitute}", {use_substitute})\n"""
473
- return deprecated
474
-
475
-
476
- def _generate_arg_handler(class_name, arg, arg_handler, is_optional):
477
- """
478
- Generates the argument handler call for an argument.
479
-
480
- Args:
481
- class_name (str): The name of the class.
482
- arg (str): The name of the argument.
483
- arg_handler (str): The handler function for the argument.
484
- is_optional (bool): Indicates whether the argument is optional.
485
-
486
- Returns:
487
- str: The argument handler call string.
488
- """
489
- arg_handler_call = f"""{arg_handler}('{class_name}', '{arg}', {arg})"""
490
- if is_optional:
491
- arg_handler_call = f"""{arg} if {arg} is None else {arg_handler_call}"""
492
- return arg_handler_call
@@ -23,7 +23,7 @@ from common.template import Template
23
23
  import common.gen_constants as K
24
24
  from common.gen_utils import save_file
25
25
  from common.base_generator import BaseGenerator
26
- from pyboost.pyboost_utils import is_optional_param, get_input_dtype, is_op_multi_output
26
+ from pyboost.pyboost_utils import is_optional_param, get_input_dtype, is_op_multi_output, get_output_dtype
27
27
 
28
28
 
29
29
  class AutoGradImplGenerator(BaseGenerator):
@@ -38,6 +38,8 @@ class AutoGradImplGenerator(BaseGenerator):
38
38
  self.OP_DEF_INC_HEAD_TEMPLATE = template.OP_DEF_INC_HEAD_TEMPLATE
39
39
  self.AUTO_GRAD_IMPL_CC_TEMPLATE = template.AUTO_GRAD_IMPL_CC_TEMPLATE
40
40
  self.DO_GRAD_FUNCTION_BODY_TEMPLATE = template.DO_GRAD_FUNCTION_BODY_TEMPLATE
41
+ self.DO_VIEW_GRAD_FUNCTION_BODY_TEMPLATE = template.DO_VIEW_GRAD_FUNCTION_BODY_TEMPLATE
42
+ self.DO_VIEW_CUSTOMIZE_GRAD_FUNCTION_BODY_TEMPLATE = template.DO_VIEW_CUSTOMIZE_GRAD_FUNCTION_BODY_TEMPLATE
41
43
  self.auto_grad_reg_template = Template("const_cast<kernel::pyboost::${class_name}GradFunc&>(" + \
42
44
  "kernel::pyboost::AutoGradFactory::Get()." + \
43
45
  "ops_auto_grad_registers().${class_name}GradFuncObj) = " + \
@@ -45,6 +47,9 @@ class AutoGradImplGenerator(BaseGenerator):
45
47
  self.do_grad_op_args_with_type = Template(
46
48
  "const kernel::pyboost::OpPtr &op, ${input_args_with_type}"
47
49
  )
50
+ self.do_grad_view_op_args_with_type = Template(
51
+ "${output_args_with_type}, ${input_args_with_type}"
52
+ )
48
53
 
49
54
  def generate(self, work_path, op_protos):
50
55
  """
@@ -60,8 +65,13 @@ class AutoGradImplGenerator(BaseGenerator):
60
65
  for op_proto in op_protos:
61
66
  if op_proto.op_dispatch is None:
62
67
  continue
68
+ # the backward func of flatten_ext and t_ext are implemented by other view ops, just continue
69
+ if op_proto.op_view and not op_proto.bprop_expander:
70
+ continue
63
71
  auto_grad_reg_list.append(self.auto_grad_reg_template.replace(class_name=op_proto.op_class.name))
64
- do_grad_op_list.append(self._get_single_do_grad_op(op_proto))
72
+ do_single_grad_op_str = self._get_single_do_grad_view_op(op_proto)\
73
+ if op_proto.op_view else self._get_single_do_grad_op(op_proto)
74
+ do_grad_op_list.append(do_single_grad_op_str)
65
75
  ops_inc_head_set.add(self.OP_DEF_INC_HEAD_TEMPLATE.replace(prefix_char=op_proto.op_class.name[0].lower()))
66
76
  pyboost_func_h_str = self.AUTO_GRAD_IMPL_CC_TEMPLATE.replace(do_grad_op=do_grad_op_list,
67
77
  auto_grad_reg=auto_grad_reg_list,
@@ -80,12 +90,11 @@ class AutoGradImplGenerator(BaseGenerator):
80
90
  Returns:
81
91
  str: The generated DoGrad function string.
82
92
  """
83
- input_args_str = self._get_input_args(op_proto, False, False, op_proto.op_view)
84
- input_args_with_optional_str = self._get_input_args(op_proto, False, True, op_proto.op_view)
85
- input_args_with_type_str = self._get_input_args(op_proto, True, False, op_proto.op_view)
93
+ input_args_str = self._get_input_args(op_proto, False, False, False)
94
+ input_args_with_optional_str = self._get_input_args(op_proto, False, True, False)
95
+ input_args_with_type_str = self._get_input_args(op_proto, True, False, False)
86
96
  inner_grad_args_with_type = self._get_input_args(op_proto, True, False, False)
87
97
  multi_output_str = 'Multi' if is_op_multi_output(op_proto.op_returns) else ''
88
- view_arg_str = self._get_view_str(op_proto.op_view, input_args_str)
89
98
  grad_args_with_type_str = self.do_grad_op_args_with_type.replace(input_args_with_type=input_args_with_type_str)
90
99
  inner_grad_args_with_type =\
91
100
  self.do_grad_op_args_with_type.replace(input_args_with_type=inner_grad_args_with_type)
@@ -94,22 +103,62 @@ class AutoGradImplGenerator(BaseGenerator):
94
103
  FALSE = "false"
95
104
  bprop_expander = TRUE if op_proto.bprop_expander else FALSE
96
105
  non_differentiable = TRUE if op_proto.non_differentiable else FALSE
97
- if not op_proto.op_view:
98
- convert_basic_to_value = ''
99
- else:
100
- input_args_with_optional_str, convert_basic_to_value = self._get_convert_str(op_proto,
101
- input_args_with_optional_str)
106
+
102
107
  return self.DO_GRAD_FUNCTION_BODY_TEMPLATE.replace(class_name=op_proto.op_class.name,
103
108
  inner_grad_args_with_type=inner_grad_args_with_type,
104
109
  grad_args_with_type=grad_args_with_type_str,
105
110
  grad_input_args=input_args_str,
106
111
  grad_input_args_with_optional=input_args_with_optional_str,
107
112
  is_multi=multi_output_str,
108
- view_arg=view_arg_str,
109
113
  op_def_name=op_def_name_str,
110
114
  bprop_expander=bprop_expander,
111
- non_differentiable=non_differentiable,
112
- convert_basic_to_value=convert_basic_to_value)
115
+ non_differentiable=non_differentiable)
116
+
117
+ def _get_single_do_grad_view_op(self, op_proto):
118
+ """
119
+ Generate the DoGrad function for a single view operator prototype.
120
+
121
+ Args:
122
+ op_proto: The operator prototype for which the DoGrad function is generated.
123
+
124
+ Returns:
125
+ str: The generated DoGrad function string.
126
+ """
127
+ input_args_str = self._get_input_args(op_proto, False, False, True)
128
+ input_args_with_optional_str = self._get_input_args(op_proto, False, True, True)
129
+ input_args_with_type_str = self._get_input_args(op_proto, True, False, True)
130
+ inner_grad_args_with_type = self._get_input_args(op_proto, True, False, False)
131
+ view_arg_str = self._get_view_str(input_args_str)
132
+ grad_args_with_type_str = self.do_grad_view_op_args_with_type\
133
+ .replace(input_args_with_type=input_args_with_type_str,
134
+ output_args_with_type=self._get_output_arg(op_proto))
135
+ inner_grad_args_with_type =\
136
+ self.do_grad_view_op_args_with_type.replace(output_args_with_type="const ValuePtr &output_value",
137
+ input_args_with_type=inner_grad_args_with_type)
138
+ op_def_name_str = "g" + op_proto.op_class.name
139
+ TRUE = "true"
140
+ FALSE = "false"
141
+ bprop_expander = TRUE if op_proto.bprop_expander else FALSE
142
+ non_differentiable = TRUE if op_proto.non_differentiable else FALSE
143
+ if op_proto.op_name in ["reshape", "expand_dims", "transpose", "slice_ext_view",\
144
+ "select_ext_view", "transpose_ext_view"]:
145
+ do_view_grad_function_body_tpl = self.DO_VIEW_CUSTOMIZE_GRAD_FUNCTION_BODY_TEMPLATE
146
+ convert_basic_to_value = ""
147
+ else:
148
+ do_view_grad_function_body_tpl = self.DO_VIEW_GRAD_FUNCTION_BODY_TEMPLATE
149
+ input_args_with_optional_str, convert_basic_to_value = self._get_convert_str(op_proto,
150
+ input_args_with_optional_str)
151
+ return do_view_grad_function_body_tpl.replace(class_name=op_proto.op_class.name,
152
+ inner_grad_args_with_type=inner_grad_args_with_type,
153
+ grad_args_with_type=grad_args_with_type_str,
154
+ grad_input_args=input_args_str,
155
+ grad_input_args_with_optional=input_args_with_optional_str,
156
+ view_arg=view_arg_str,
157
+ op_def_name=op_def_name_str,
158
+ bprop_expander=bprop_expander,
159
+ non_differentiable=non_differentiable,
160
+ convert_basic_to_value=convert_basic_to_value)
161
+
113
162
 
114
163
  def _get_input_args(self, op_proto, has_type, with_optional, use_basic_type=False):
115
164
  """
@@ -134,6 +183,15 @@ class AutoGradImplGenerator(BaseGenerator):
134
183
  args_list.append(f"{op_arg.arg_name}_tensor")
135
184
  return args_list
136
185
 
186
+ def _get_output_arg(self, op_proto):
187
+ # for view operators, the output is tensor or vector<tensor>
188
+ if len(op_proto.op_returns) != 1:
189
+ raise ValueError(f"the output of {op_proto.op_name} is not tensor, ",
190
+ "tuple[tensor] or list[tensor], which is not not as expected")
191
+ output_dtype = get_output_dtype(op_proto.op_returns[0].arg_dtype)
192
+ output_arg = f"const {output_dtype} &output"
193
+ return output_arg
194
+
137
195
  def _get_convert_str(self, op_proto, args_name):
138
196
  """
139
197
  Get the input convert func for the DoGrad function.
@@ -161,12 +219,11 @@ class AutoGradImplGenerator(BaseGenerator):
161
219
  args_name_list.append(out_arg_name)
162
220
  return args_name_list, convert_funcs
163
221
 
164
- def _get_view_str(self, is_view_op: bool, grad_args: list):
222
+ def _get_view_str(self, grad_args: list):
165
223
  """
166
224
  Get the view argument string for a DoGrad function.
167
225
 
168
226
  Args:
169
- is_view_op (bool): Whether the operator is a view operator.
170
227
  grad_args (list): A list of gradient arguments.
171
228
 
172
229
  Returns:
@@ -174,7 +231,7 @@ class AutoGradImplGenerator(BaseGenerator):
174
231
  """
175
232
  view_arg_str = ''
176
233
  for i, grad_arg in enumerate(grad_args):
177
- if is_view_op and i == 0:
234
+ if i == 0:
178
235
  view_arg_str = ", " + grad_arg
179
236
  break
180
237
  return view_arg_str
@@ -23,7 +23,7 @@ from common.template import Template
23
23
  import common.gen_constants as K
24
24
  from common.gen_utils import save_file
25
25
  from common.base_generator import BaseGenerator
26
- from pyboost.pyboost_utils import is_optional_param, get_input_dtype
26
+ from pyboost.pyboost_utils import is_optional_param, get_input_dtype, get_output_dtype
27
27
 
28
28
 
29
29
  class AutoGradRegHeaderGenerator(BaseGenerator):
@@ -42,6 +42,9 @@ class AutoGradRegHeaderGenerator(BaseGenerator):
42
42
  self.op_grad_func_args_template = Template(
43
43
  "const kernel::pyboost::OpPtr &, ${input_tensor_prt_args}"
44
44
  )
45
+ self.op_view_grad_func_args_template = Template(
46
+ "${output_tensor_prt_args}, ${input_tensor_prt_args}"
47
+ )
45
48
 
46
49
  def generate(self, work_path, op_protos):
47
50
  """
@@ -60,9 +63,13 @@ class AutoGradRegHeaderGenerator(BaseGenerator):
60
63
  continue
61
64
  op_type_enum_list.append(self.op_type_enum_template.replace(class_name=op_proto.op_class.name,
62
65
  enum_val=index))
66
+ # the backward func of flatten_ext and t_ext are implemented by other view ops, just continue
67
+ if op_proto.op_view and not op_proto.bprop_expander:
68
+ continue
63
69
  grad_func_args_with_type_str = self._get_grad_func_args_with_type_str(op_proto)
64
- op_grad_func_list.append(self.op_grad_func_template.replace(class_name=op_proto.op_class.name,
65
- grad_func_args=grad_func_args_with_type_str))
70
+ op_grad_func_list.append(
71
+ self.op_grad_func_template.replace(class_name=op_proto.op_class.name,
72
+ grad_func_args=grad_func_args_with_type_str))
66
73
  op_grad_func_obj_list.append(self.op_grad_func_obj_template.replace(class_name=op_proto.op_class.name))
67
74
  index += 1
68
75
 
@@ -89,5 +96,15 @@ class AutoGradRegHeaderGenerator(BaseGenerator):
89
96
  is_optional = is_optional_param(op_arg)
90
97
  input_dtype = get_input_dtype(op_arg.arg_dtype, is_optional, op_proto.op_view)
91
98
  input_tensor_prt_args_str += f"const {input_dtype} &, "
92
-
93
- return self.op_grad_func_args_template.replace(input_tensor_prt_args=input_tensor_prt_args_str.rstrip(', '))
99
+ input_tensor_prt_args_str = input_tensor_prt_args_str.rstrip(', ')
100
+ if not op_proto.op_view:
101
+ return self.op_grad_func_args_template.replace(input_tensor_prt_args=\
102
+ input_tensor_prt_args_str)
103
+ # for view operators, the output is tensor or vector<tensor>
104
+ if len(op_proto.op_returns) != 1:
105
+ raise ValueError(f"the output of {op_proto.op_name} is not tensor,",
106
+ "tuple[tensor] or list[tensor], which is not not as expected")
107
+ output_dtype = get_output_dtype(op_proto.op_returns[0].arg_dtype)
108
+ output_tensor_prt_args_str = f"const {output_dtype} &"
109
+ return self.op_view_grad_func_args_template.replace(input_tensor_prt_args=input_tensor_prt_args_str,
110
+ output_tensor_prt_args=output_tensor_prt_args_str)
@@ -16,9 +16,6 @@
16
16
  Generate pyboost function from pyboost_op.yaml
17
17
  """
18
18
 
19
- import os
20
- import shutil
21
- import logging
22
19
  from resources.resource_list import ResourceType
23
20
  from common import gen_constants as K
24
21
  from api.functions_cc_generator import FunctionsGenerator, FunctionsHeaderGenerator
@@ -48,18 +45,6 @@ from .auto_grad_impl_cc_generator import AutoGradImplGenerator
48
45
  from .auto_grad_reg_cc_generator import AutoGradRegHeaderGenerator
49
46
 
50
47
 
51
- def clear_old_generated_code(work_path):
52
- """ delete old generated files to prevent compilation failure """
53
- files_to_clear = ['mindspore/ops/kernel/common/pyboost',
54
- 'mindspore/ops/kernel/functions/auto_generate',
55
- 'mindspore/ccsrc/runtime/pynative/op_function']
56
- for f in files_to_clear:
57
- real_path = os.path.join(work_path, f)
58
- if os.path.exists(real_path):
59
- shutil.rmtree(real_path)
60
- logging.warning("rm file %s", real_path)
61
-
62
-
63
48
  def gen_pyboost_code(resource_mgr):
64
49
  """ gen_pyboost_code """
65
50
  work_path = K.WORK_DIR
@@ -67,7 +52,6 @@ def gen_pyboost_code(resource_mgr):
67
52
  doc_yaml_data = resource_mgr.get_resource(ResourceType.OP_DOC_YAML)
68
53
  mint_func_protos = resource_mgr.get_resource(ResourceType.MINT_FUNC_PROTOS)
69
54
  alias_func_mapping = resource_mgr.get_resource(ResourceType.ALIAS_API_MAPPING)
70
- clear_old_generated_code(work_path)
71
55
  call_pyboost_inner_prim_generator(work_path, op_protos)
72
56
  call_pyboost_functions_py_generator(work_path, op_protos, doc_yaml_data)
73
57
  call_pyboost_functions_h_generator(work_path, op_protos)
@@ -47,14 +47,15 @@ class OpTemplateParser:
47
47
  self.op_proto = op_proto
48
48
  self.tensor_arg_handler_prt_template = Template(
49
49
  "parse_args.arg_list_[${idx}] = "
50
- "py::cast((*pynative::${func_str}(\"${func_name}\", \"${op_arg_name}\", "
50
+ "PyLong_FromLong((*pynative::${func_str}(\"${func_name}\", \"${op_arg_name}\", "
51
51
  "parse_args.arg_list_[${idx}]))->value());\n"
52
52
  "parse_args.src_types_[${idx}] = ops::OP_DTYPE::DT_BEGIN;\n"
53
53
  "parse_args.dst_types_[${idx}] = ${new_type};\n"
54
54
  )
55
55
  self.function_arg_handler_prt_template = Template(
56
56
  "parse_args.arg_list_[${idx}] = "
57
- "py::cast((*${func_str}(\"${func_name}\", \"${op_arg_name}\", parse_args.arg_list_[${idx}]))->value());\n"
57
+ "PyLong_FromLong((*${func_str}(\"${func_name}\", \"${op_arg_name}\", "
58
+ "parse_args.arg_list_[${idx}]))->value());\n"
58
59
  "parse_args.src_types_[${idx}] = ops::OP_DTYPE::DT_BEGIN;\n"
59
60
  "parse_args.dst_types_[${idx}] = ${new_type};\n"
60
61
  )