mindspore 2.4.1__cp39-cp39-win_amd64.whl → 2.5.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (395) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +8 -3
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +0 -5
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/compile_config.py +64 -0
  11. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
  13. mindspore/_extends/parse/parser.py +23 -5
  14. mindspore/_extends/parse/standard_method.py +123 -27
  15. mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
  16. mindspore/amp.py +7 -1
  17. mindspore/atlprov.dll +0 -0
  18. mindspore/avcodec-59.dll +0 -0
  19. mindspore/avdevice-59.dll +0 -0
  20. mindspore/avfilter-8.dll +0 -0
  21. mindspore/avformat-59.dll +0 -0
  22. mindspore/avutil-57.dll +0 -0
  23. mindspore/boost/boost_cell_wrapper.py +136 -41
  24. mindspore/c1.dll +0 -0
  25. mindspore/c1xx.dll +0 -0
  26. mindspore/c2.dll +0 -0
  27. mindspore/common/__init__.py +3 -1
  28. mindspore/common/_register_for_tensor.py +0 -1
  29. mindspore/common/_stub_tensor.py +25 -4
  30. mindspore/common/_tensor_cpp_method.py +17 -0
  31. mindspore/common/_tensor_docs.py +6132 -0
  32. mindspore/common/api.py +99 -25
  33. mindspore/common/dtype.py +34 -34
  34. mindspore/common/dump.py +2 -1
  35. mindspore/common/file_system.py +8 -1
  36. mindspore/common/generator.py +2 -0
  37. mindspore/common/hook_handle.py +3 -1
  38. mindspore/common/initializer.py +3 -4
  39. mindspore/common/lazy_inline.py +8 -2
  40. mindspore/common/mindir_util.py +10 -2
  41. mindspore/common/parameter.py +30 -27
  42. mindspore/common/tensor.py +713 -1337
  43. mindspore/communication/__init__.py +1 -1
  44. mindspore/communication/_comm_helper.py +10 -0
  45. mindspore/communication/comm_func.py +215 -173
  46. mindspore/communication/management.py +23 -20
  47. mindspore/context.py +292 -193
  48. mindspore/dataset/__init__.py +23 -19
  49. mindspore/dataset/callback/ds_callback.py +2 -1
  50. mindspore/dataset/core/config.py +84 -3
  51. mindspore/dataset/engine/cache_admin.py +3 -3
  52. mindspore/dataset/engine/cache_client.py +5 -4
  53. mindspore/dataset/engine/datasets.py +192 -149
  54. mindspore/dataset/engine/datasets_audio.py +14 -0
  55. mindspore/dataset/engine/datasets_standard_format.py +28 -11
  56. mindspore/dataset/engine/datasets_text.py +38 -1
  57. mindspore/dataset/engine/datasets_user_defined.py +125 -65
  58. mindspore/dataset/engine/datasets_vision.py +81 -8
  59. mindspore/dataset/engine/iterators.py +281 -63
  60. mindspore/dataset/engine/obs/util.py +8 -0
  61. mindspore/dataset/engine/queue.py +40 -0
  62. mindspore/dataset/engine/samplers.py +26 -2
  63. mindspore/dataset/engine/serializer_deserializer.py +1 -1
  64. mindspore/dataset/engine/validators.py +43 -11
  65. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  66. mindspore/dataset/transforms/transforms.py +29 -12
  67. mindspore/dataset/vision/validators.py +1 -2
  68. mindspore/device_context/__init__.py +21 -0
  69. mindspore/device_context/ascend/__init__.py +25 -0
  70. mindspore/device_context/ascend/device.py +72 -0
  71. mindspore/device_context/ascend/op_debug.py +94 -0
  72. mindspore/device_context/ascend/op_precision.py +193 -0
  73. mindspore/device_context/ascend/op_tuning.py +127 -0
  74. mindspore/device_context/cpu/__init__.py +25 -0
  75. mindspore/device_context/cpu/device.py +62 -0
  76. mindspore/device_context/cpu/op_tuning.py +43 -0
  77. mindspore/device_context/gpu/__init__.py +21 -0
  78. mindspore/device_context/gpu/device.py +70 -0
  79. mindspore/device_context/gpu/op_precision.py +67 -0
  80. mindspore/device_context/gpu/op_tuning.py +175 -0
  81. mindspore/device_manager.py +134 -0
  82. mindspore/dnnl.dll +0 -0
  83. mindspore/dpcmi.dll +0 -0
  84. mindspore/experimental/llm_boost/__init__.py +3 -2
  85. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  86. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  87. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  88. mindspore/experimental/llm_boost/atb/boost_base.py +239 -64
  89. mindspore/experimental/llm_boost/atb/llama_boost.py +52 -30
  90. mindspore/experimental/llm_boost/atb/qwen_boost.py +47 -24
  91. mindspore/experimental/llm_boost/register.py +1 -0
  92. mindspore/experimental/optim/adadelta.py +26 -22
  93. mindspore/experimental/optim/adam.py +3 -0
  94. mindspore/experimental/optim/lr_scheduler.py +33 -24
  95. mindspore/experimental/optim/radam.py +33 -30
  96. mindspore/hal/device.py +28 -0
  97. mindspore/hal/event.py +17 -0
  98. mindspore/hal/memory.py +94 -3
  99. mindspore/hal/stream.py +91 -6
  100. mindspore/include/api/context.h +1 -2
  101. mindspore/include/dataset/constants.h +2 -2
  102. mindspore/jpeg62.dll +0 -0
  103. mindspore/log.py +12 -0
  104. mindspore/mindrecord/__init__.py +1 -1
  105. mindspore/mindrecord/config.py +17 -316
  106. mindspore/mindrecord/filereader.py +1 -9
  107. mindspore/mindrecord/filewriter.py +5 -15
  108. mindspore/mindrecord/mindpage.py +1 -9
  109. mindspore/mindspore_backend.dll +0 -0
  110. mindspore/mindspore_common.dll +0 -0
  111. mindspore/mindspore_core.dll +0 -0
  112. mindspore/mindspore_glog.dll +0 -0
  113. mindspore/mindspore_ops.dll +0 -0
  114. mindspore/mint/__init__.py +824 -218
  115. mindspore/mint/distributed/__init__.py +66 -4
  116. mindspore/mint/distributed/distributed.py +2594 -44
  117. mindspore/mint/linalg/__init__.py +6 -0
  118. mindspore/mint/nn/__init__.py +473 -14
  119. mindspore/mint/nn/functional.py +486 -11
  120. mindspore/mint/nn/layer/__init__.py +17 -4
  121. mindspore/mint/nn/layer/_functions.py +330 -0
  122. mindspore/mint/nn/layer/activation.py +169 -1
  123. mindspore/mint/nn/layer/basic.py +123 -0
  124. mindspore/mint/nn/layer/conv.py +727 -0
  125. mindspore/mint/nn/layer/normalization.py +215 -19
  126. mindspore/mint/nn/layer/padding.py +797 -0
  127. mindspore/mint/nn/layer/pooling.py +170 -0
  128. mindspore/mint/optim/__init__.py +2 -1
  129. mindspore/mint/optim/adam.py +223 -0
  130. mindspore/mint/optim/adamw.py +26 -19
  131. mindspore/mint/special/__init__.py +2 -1
  132. mindspore/msobj140.dll +0 -0
  133. mindspore/mspdb140.dll +0 -0
  134. mindspore/mspdbcore.dll +0 -0
  135. mindspore/mspdbst.dll +0 -0
  136. mindspore/mspft140.dll +0 -0
  137. mindspore/msvcdis140.dll +0 -0
  138. mindspore/msvcp140_1.dll +0 -0
  139. mindspore/msvcp140_2.dll +0 -0
  140. mindspore/msvcp140_atomic_wait.dll +0 -0
  141. mindspore/msvcp140_codecvt_ids.dll +0 -0
  142. mindspore/multiprocessing/__init__.py +5 -0
  143. mindspore/nn/__init__.py +2 -0
  144. mindspore/nn/cell.py +142 -21
  145. mindspore/nn/dynamic_lr.py +2 -1
  146. mindspore/nn/layer/activation.py +6 -6
  147. mindspore/nn/layer/basic.py +35 -25
  148. mindspore/nn/layer/channel_shuffle.py +3 -3
  149. mindspore/nn/layer/conv.py +3 -0
  150. mindspore/nn/layer/embedding.py +3 -3
  151. mindspore/nn/layer/normalization.py +8 -7
  152. mindspore/nn/layer/padding.py +4 -3
  153. mindspore/nn/layer/pooling.py +55 -23
  154. mindspore/nn/layer/rnn_cells.py +1 -1
  155. mindspore/nn/layer/rnns.py +2 -1
  156. mindspore/nn/layer/timedistributed.py +5 -5
  157. mindspore/nn/layer/transformer.py +48 -26
  158. mindspore/nn/learning_rate_schedule.py +5 -3
  159. mindspore/nn/loss/loss.py +31 -36
  160. mindspore/nn/optim/ada_grad.py +1 -0
  161. mindspore/nn/optim/adadelta.py +2 -2
  162. mindspore/nn/optim/adam.py +1 -1
  163. mindspore/nn/optim/lars.py +1 -4
  164. mindspore/nn/optim/optimizer.py +1 -1
  165. mindspore/nn/optim/rprop.py +2 -2
  166. mindspore/nn/optim/thor.py +2 -1
  167. mindspore/nn/utils/__init__.py +22 -0
  168. mindspore/nn/utils/init.py +73 -0
  169. mindspore/nn/wrap/cell_wrapper.py +4 -6
  170. mindspore/nn/wrap/loss_scale.py +3 -4
  171. mindspore/numpy/array_creations.py +60 -62
  172. mindspore/numpy/array_ops.py +148 -143
  173. mindspore/numpy/logic_ops.py +41 -42
  174. mindspore/numpy/math_ops.py +361 -359
  175. mindspore/numpy/utils.py +16 -16
  176. mindspore/numpy/utils_const.py +4 -4
  177. mindspore/opencv_core452.dll +0 -0
  178. mindspore/opencv_imgcodecs452.dll +0 -0
  179. mindspore/opencv_imgproc452.dll +0 -0
  180. mindspore/ops/__init__.py +2 -1
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +107 -8
  182. mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
  183. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  184. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  185. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  186. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  187. mindspore/ops/_vmap/vmap_array_ops.py +20 -19
  188. mindspore/ops/_vmap/vmap_base.py +0 -2
  189. mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
  190. mindspore/ops/_vmap/vmap_math_ops.py +11 -9
  191. mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
  192. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
  193. mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
  194. mindspore/ops/auto_generate/gen_extend_func.py +554 -60
  195. mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
  196. mindspore/ops/auto_generate/gen_ops_prim.py +8027 -3411
  197. mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
  198. mindspore/ops/composite/base.py +1 -1
  199. mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
  200. mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
  201. mindspore/ops/function/__init__.py +12 -0
  202. mindspore/ops/function/array_func.py +561 -159
  203. mindspore/ops/function/clip_func.py +64 -0
  204. mindspore/ops/function/debug_func.py +28 -20
  205. mindspore/ops/function/image_func.py +1 -1
  206. mindspore/ops/function/linalg_func.py +5 -4
  207. mindspore/ops/function/math_func.py +1664 -294
  208. mindspore/ops/function/nn_func.py +988 -317
  209. mindspore/ops/function/parameter_func.py +3 -56
  210. mindspore/ops/function/random_func.py +243 -33
  211. mindspore/ops/function/sparse_unary_func.py +1 -1
  212. mindspore/ops/functional.py +18 -5
  213. mindspore/ops/functional_overload.py +897 -0
  214. mindspore/ops/operations/__init__.py +3 -2
  215. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  216. mindspore/ops/operations/_grad_ops.py +2 -34
  217. mindspore/ops/operations/_infer_ops.py +2 -1
  218. mindspore/ops/operations/_inner_ops.py +38 -8
  219. mindspore/ops/operations/array_ops.py +45 -303
  220. mindspore/ops/operations/comm_ops.py +23 -17
  221. mindspore/ops/operations/custom_ops.py +7 -49
  222. mindspore/ops/operations/debug_ops.py +42 -47
  223. mindspore/ops/operations/inner_ops.py +6 -4
  224. mindspore/ops/operations/linalg_ops.py +3 -2
  225. mindspore/ops/operations/manually_defined/ops_def.py +185 -104
  226. mindspore/ops/operations/math_ops.py +11 -216
  227. mindspore/ops/operations/nn_ops.py +153 -310
  228. mindspore/ops/primitive.py +23 -21
  229. mindspore/ops/tensor_method.py +1669 -0
  230. mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
  231. mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
  232. mindspore/ops_generate/arg_handler.py +0 -61
  233. mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
  234. mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
  235. mindspore/ops_generate/base_generator.py +11 -0
  236. mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
  237. mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
  238. mindspore/ops_generate/functional_overload_py_generator.py +110 -0
  239. mindspore/ops_generate/functions_cc_generator.py +233 -0
  240. mindspore/ops_generate/gen_aclnn_implement.py +110 -114
  241. mindspore/ops_generate/gen_constants.py +157 -3
  242. mindspore/ops_generate/gen_ops.py +245 -990
  243. mindspore/ops_generate/gen_pyboost_func.py +97 -998
  244. mindspore/ops_generate/gen_utils.py +119 -33
  245. mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
  246. mindspore/ops_generate/op_api_proto.py +206 -0
  247. mindspore/ops_generate/op_def_py_generator.py +131 -0
  248. mindspore/ops_generate/op_prim_py_generator.py +480 -0
  249. mindspore/ops_generate/op_proto.py +373 -108
  250. mindspore/ops_generate/op_template_parser.py +436 -0
  251. mindspore/ops_generate/ops_def_cc_generator.py +288 -0
  252. mindspore/ops_generate/ops_def_h_generator.py +74 -0
  253. mindspore/ops_generate/ops_name_h_generator.py +68 -0
  254. mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
  255. mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
  256. mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
  257. mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
  258. mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
  259. mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
  260. mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
  261. mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
  262. mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
  263. mindspore/ops_generate/pyboost_utils.py +92 -33
  264. mindspore/ops_generate/template.py +294 -44
  265. mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
  266. mindspore/parallel/__init__.py +3 -3
  267. mindspore/parallel/_auto_parallel_context.py +44 -34
  268. mindspore/parallel/_cell_wrapper.py +22 -3
  269. mindspore/parallel/_parallel_serialization.py +13 -2
  270. mindspore/parallel/_utils.py +4 -2
  271. mindspore/parallel/algo_parameter_config.py +1 -1
  272. mindspore/parallel/checkpoint_transform.py +44 -0
  273. mindspore/parallel/cluster/process_entity/_api.py +131 -37
  274. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  275. mindspore/parallel/cluster/run.py +20 -3
  276. mindspore/parallel/parameter_broadcast.py +1 -1
  277. mindspore/parallel/shard.py +3 -0
  278. mindspore/parallel/transform_safetensors.py +119 -253
  279. mindspore/pgodb140.dll +0 -0
  280. mindspore/pgort140.dll +0 -0
  281. mindspore/profiler/__init__.py +17 -4
  282. mindspore/profiler/analysis/__init__.py +0 -0
  283. mindspore/profiler/analysis/parser/__init__.py +0 -0
  284. mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
  285. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  286. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  287. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  288. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  289. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  290. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
  291. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  292. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
  293. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  294. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  295. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  296. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  297. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  298. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  299. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  300. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  301. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  302. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  303. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
  304. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  305. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  306. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  307. mindspore/profiler/analysis/task_manager.py +131 -0
  308. mindspore/profiler/analysis/time_converter.py +84 -0
  309. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  310. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
  311. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  312. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
  313. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
  314. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
  315. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
  316. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  317. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  318. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
  319. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  320. mindspore/profiler/analysis/work_flow.py +73 -0
  321. mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
  322. mindspore/profiler/common/command_executor.py +90 -0
  323. mindspore/profiler/common/constant.py +174 -3
  324. mindspore/profiler/common/file_manager.py +208 -0
  325. mindspore/profiler/common/log.py +130 -0
  326. mindspore/profiler/common/msprof_cmd_tool.py +202 -0
  327. mindspore/profiler/common/path_manager.py +371 -0
  328. mindspore/profiler/common/process_bar.py +168 -0
  329. mindspore/profiler/common/process_pool.py +9 -3
  330. mindspore/profiler/common/profiler_context.py +476 -0
  331. mindspore/profiler/common/profiler_info.py +304 -0
  332. mindspore/profiler/common/profiler_output_path.py +284 -0
  333. mindspore/profiler/common/profiler_parameters.py +210 -0
  334. mindspore/profiler/common/profiler_path_manager.py +120 -0
  335. mindspore/profiler/common/record_function.py +76 -0
  336. mindspore/profiler/common/tlv_decoder.py +76 -0
  337. mindspore/profiler/common/util.py +75 -2
  338. mindspore/profiler/dynamic_profiler.py +270 -37
  339. mindspore/profiler/envprofiler.py +138 -0
  340. mindspore/profiler/mstx.py +199 -0
  341. mindspore/profiler/platform/__init__.py +21 -0
  342. mindspore/profiler/platform/base_profiler.py +40 -0
  343. mindspore/profiler/platform/cpu_profiler.py +124 -0
  344. mindspore/profiler/platform/gpu_profiler.py +74 -0
  345. mindspore/profiler/platform/npu_profiler.py +309 -0
  346. mindspore/profiler/profiler.py +580 -93
  347. mindspore/profiler/profiler_action_controller.py +187 -0
  348. mindspore/profiler/profiler_interface.py +114 -0
  349. mindspore/profiler/schedule.py +208 -0
  350. mindspore/rewrite/api/symbol_tree.py +1 -2
  351. mindspore/run_check/_check_version.py +18 -13
  352. mindspore/runtime/__init__.py +37 -0
  353. mindspore/runtime/device.py +27 -0
  354. mindspore/runtime/event.py +209 -0
  355. mindspore/runtime/executor.py +148 -0
  356. mindspore/runtime/memory.py +392 -0
  357. mindspore/runtime/stream.py +460 -0
  358. mindspore/runtime/thread_bind_core.py +401 -0
  359. mindspore/swresample-4.dll +0 -0
  360. mindspore/swscale-6.dll +0 -0
  361. mindspore/tbbmalloc.dll +0 -0
  362. mindspore/tinyxml2.dll +0 -0
  363. mindspore/train/__init__.py +2 -2
  364. mindspore/train/_utils.py +53 -18
  365. mindspore/train/amp.py +8 -4
  366. mindspore/train/callback/_checkpoint.py +32 -18
  367. mindspore/train/callback/_early_stop.py +1 -1
  368. mindspore/train/callback/_flops_collector.py +105 -69
  369. mindspore/train/callback/_history.py +1 -1
  370. mindspore/train/callback/_summary_collector.py +44 -6
  371. mindspore/train/callback/_tft_register.py +37 -15
  372. mindspore/train/dataset_helper.py +11 -11
  373. mindspore/train/metrics/precision.py +4 -5
  374. mindspore/train/mind_ir_pb2.py +167 -46
  375. mindspore/train/model.py +13 -14
  376. mindspore/train/serialization.py +461 -72
  377. mindspore/train/summary/summary_record.py +1 -2
  378. mindspore/train/train_thor/model_thor.py +1 -1
  379. mindspore/turbojpeg.dll +0 -0
  380. mindspore/utils/__init__.py +4 -2
  381. mindspore/utils/dryrun.py +138 -0
  382. mindspore/utils/runtime_execution_order_check.py +550 -0
  383. mindspore/vcmeta.dll +0 -0
  384. mindspore/vcruntime140.dll +0 -0
  385. mindspore/vcruntime140_1.dll +0 -0
  386. mindspore/version.py +1 -1
  387. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/METADATA +3 -4
  388. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/RECORD +391 -265
  389. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
  390. mindspore/common/_tensor_overload.py +0 -139
  391. mindspore/mindspore_np_dtype.dll +0 -0
  392. mindspore/profiler/envprofiling.py +0 -254
  393. mindspore/profiler/profiling.py +0 -1926
  394. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
  395. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
@@ -15,1006 +15,127 @@
15
15
  """
16
16
  Generate operator definition from ops.yaml
17
17
  """
18
+ import argparse
19
+ import copy
20
+ import logging
18
21
  import os
19
- import re
20
22
  import shutil
21
23
  import pathlib
22
- import logging
23
- import gen_utils
24
- from gen_utils import (py_licence_str, cc_license_str, check_change_and_replace_file, merge_files,
25
- merge_files_append, safe_load_yaml, convert_dtype_str, write_file)
26
- from pyboost_utils import get_pyboost_name, is_pyboost_enable, AclnnUtils, get_dtypes
27
- import template
28
- from template import CppTemplate
24
+ from gen_utils import (check_change_and_replace_file, merge_files,
25
+ merge_files_append, safe_load_yaml)
26
+ from op_prim_py_generator import OpPrimPyGenerator
27
+ from op_def_py_generator import OpDefPyGenerator
28
+ from aclnn_kernel_register_auto_cc_generator import AclnnKernelRegisterAutoCcGenerator
29
+ from cpp_create_prim_instance_helper_generator import CppCreatePrimInstanceHelperGenerator
30
+ from ops_def_cc_generator import OpsDefCcGenerator
31
+ from ops_def_h_generator import OpsDefHGenerator
32
+ from ops_primitive_h_generator import OpsPrimitiveHGenerator
33
+ from lite_ops_cpp_generator import LiteOpsCcGenerator, LiteOpsHGenerator
34
+ from ops_name_h_generator import OpsNameHGenerator
35
+ from functional_map_cpp_generator import FunctionalMapCppGenerator
36
+ from add_tensor_docs_generator import AddTensorDocsGenerator
37
+ from functional_overload_py_generator import FunctionalOverloadPyGenerator
38
+
39
+ from op_proto import OpProto
40
+ from op_api_proto import load_api_protos_from_yaml
41
+ from tensor_func_reg_cpp_generator import TensorFuncRegCppGenerator
29
42
  from gen_pyboost_func import gen_pyboost_code
30
- from gen_aclnn_implement import gen_aclnn_kernel
31
- import gen_constants as K
32
-
33
-
34
- def _get_op_name(yaml_key, yaml_value):
35
- """
36
- Get op name for python class Primitive or c++ OpDef name.
37
- """
38
- # If has class item, use the specified item.
39
- class_def = yaml_value.get("class")
40
- if class_def is not None:
41
- class_name_specify = class_def.get("name")
42
- if class_name_specify is not None:
43
- return class_name_specify
44
- # Else use the default rule generate class name.
45
- op_name = yaml_key
46
- class_name_normal = ''.join(word.capitalize() for word in op_name.split('_'))
47
- return class_name_normal
48
-
49
-
50
- def _get_op_func_name(yaml_key, yaml_value):
51
- func_def = yaml_value.get('function')
52
- func_name = yaml_key
53
-
54
- if func_def is not None:
55
- item = func_def.get("name")
56
- if item is not None:
57
- func_name = item
58
- return func_name
59
-
60
-
61
- def _auto_generate_class_disabled(yaml_value):
62
- """Check whether class can be auto generated."""
63
- if 'class' not in yaml_value.keys():
64
- return False
65
- class_def = yaml_value.get("class")
66
- if 'disable' not in class_def.keys():
67
- return False
68
- disable_item = class_def.get("disable")
69
- if disable_item is True:
70
- return True
71
- if disable_item is False:
72
- return False
73
- raise TypeError(f"The disable label for class should be True or False, but get {disable_item}.")
74
-
75
-
76
- def _auto_generate_func_disabled(yaml_value):
77
- """Check whether function can be auto generated."""
78
- if 'function' not in yaml_value.keys():
79
- return False
80
- func_def = yaml_value.get('function')
81
- if 'disable' not in func_def.keys():
82
- return False
83
- disable_item = func_def.get("disable")
84
- if disable_item is True:
85
- return True
86
- if disable_item is False:
87
- return False
88
- raise TypeError(f"The disable label for function should be True or False, but get {disable_item}.")
89
-
90
-
91
- def signature_get_rw_label(arg_name, write_list, read_list, ref_list):
92
- """
93
- Generate signature rw code
94
- """
95
- for rw_arg_name in write_list:
96
- if rw_arg_name == arg_name:
97
- return ', sig.sig_rw.RW_WRITE'
98
- for read_arg_name in read_list:
99
- if read_arg_name == arg_name:
100
- return ', sig.sig_rw.RW_READ'
101
- for ref_arg_name in ref_list:
102
- if ref_arg_name == arg_name:
103
- return ', sig.sig_rw.RW_REF'
104
- return ''
105
-
106
-
107
- def signature_get_rw_label_cc(rw_op_name, write_list, read_list, ref_list):
108
- """
109
- Generate cc signature rw code
110
- """
111
- rw_label = 'kRWDefault'
112
- for op in write_list:
113
- if op == rw_op_name:
114
- rw_label = 'kRWWrite'
115
- for op in read_list:
116
- if op == rw_op_name:
117
- rw_label = 'kRWRead'
118
- for op in ref_list:
119
- if op == rw_op_name:
120
- rw_label = 'kRWRef'
121
- return 'SignatureEnumRW::' + rw_label
122
-
123
-
124
- def signature_get_enum_dtype_cc(index):
125
- """
126
- Generate cc enum dtype code
127
- """
128
- enum_type = 'SignatureEnumDType::'
129
- type_map = {0: 'kDType',
130
- 1: 'kDType1',
131
- 2: 'kDType2',
132
- 3: 'kDType3',
133
- 4: 'kDType4',
134
- 5: 'kDType5',
135
- 6: 'kDType6',
136
- 7: 'kDType7',
137
- 8: 'kDType8',
138
- 9: 'kDType9'}
139
- if index in type_map:
140
- return enum_type + type_map[index]
141
- return enum_type + 'kDTypeEmptyDefaultValue'
142
-
143
-
144
- def signature_get_dtype_label(index):
145
- """
146
- Generate signature dtype code
147
- """
148
- dtype_index = ''
149
- if index > 0:
150
- dtype_index = f"""{index}"""
151
- return f"""dtype=sig.sig_dtype.T{dtype_index}"""
152
-
153
-
154
- def get_same_dtype_groups(args_signature, args_name):
155
- """
156
- Get same dtype groups
157
- """
158
- same_dtype_groups = {}
159
- dtype_conut = 0
160
- if args_signature is None:
161
- return same_dtype_groups, dtype_conut
162
-
163
- dtype_group = args_signature.get('dtype_group')
164
- if dtype_group is not None:
165
- args_list = []
166
- match = re.findall(r'\((.*?)\)', dtype_group)
167
- for item in match:
168
- args_list.append(item.replace(' ', '').split(","))
169
- for arg_name in args_name:
170
- if arg_name in same_dtype_groups:
171
- continue
172
- is_match = False
173
- for group in args_list:
174
- if arg_name in group:
175
- is_match = True
176
- for item in group:
177
- same_dtype_groups[item] = dtype_conut
178
- break
179
- if not is_match:
180
- same_dtype_groups[arg_name] = dtype_conut
181
- dtype_conut = dtype_conut + 1
182
- return same_dtype_groups, dtype_conut
183
-
184
-
185
- def generate_py_op_signature(op_name, args_signature, args_name, args_default):
186
- """
187
- Generate __mindspore_signature__
188
- """
189
-
190
- def _check_signature_arg_valid(op_name, sig_arg_names, args_names):
191
- for sig_arg_name in sig_arg_names:
192
- if sig_arg_name not in args_names:
193
- raise ValueError(f"Op {op_name} has no input arg named '{sig_arg_name}'!")
194
-
195
- if args_signature is None and not args_default:
196
- return ''
197
-
198
- signature_code = f""" __mindspore_signature__ = """
199
-
200
- # Init rw.
201
- write_list = []
202
- read_list = []
203
- ref_list = []
204
- if args_signature is not None:
205
- rw_write = args_signature.get('rw_write')
206
- rw_read = args_signature.get('rw_read')
207
- rw_ref = args_signature.get('rw_ref')
208
- if rw_write is not None:
209
- write_list = rw_write.replace(' ', '').split(",")
210
- _check_signature_arg_valid(op_name, write_list, args_name)
211
- if rw_read is not None:
212
- read_list = rw_read.replace(' ', '').split(",")
213
- _check_signature_arg_valid(op_name, read_list, args_name)
214
- if rw_ref is not None:
215
- ref_list = rw_ref.replace(' ', '').split(",")
216
- _check_signature_arg_valid(op_name, ref_list, args_name)
217
- # Init dtype group.
218
- same_dtype_groups, dtype_conut = get_same_dtype_groups(args_signature, args_name)
219
- _check_signature_arg_valid(op_name, list(same_dtype_groups.keys()), args_name)
220
- # Only one dtype_group is set.
221
- if dtype_conut == 1 and not any([write_list, read_list, ref_list, args_default]):
222
- signature_code += '('
223
- for _ in range(len(args_name) - 1):
224
- signature_code += 'sig.sig_dtype.T, '
225
- signature_code += 'sig.sig_dtype.T)\n\n'
226
- return signature_code
227
-
228
- # Set sig.make_sig.
229
- signature_code += f""" (\n"""
230
- for arg_name in args_name:
231
- signature_code += f""" sig.make_sig('{arg_name}'"""
232
- signature_code += signature_get_rw_label(arg_name, write_list, read_list, ref_list)
233
- if arg_name in same_dtype_groups:
234
- signature_code += f""", """ + signature_get_dtype_label(same_dtype_groups[arg_name])
235
- if arg_name in args_default:
236
- signature_code += f""", default=""" + str(args_default[arg_name])
237
- signature_code += f"""),\n"""
238
- signature_code += f""" )\n\n"""
239
- return signature_code
240
-
241
-
242
- def generate_cc_op_signature(args_signature, args_name):
243
- """
244
- generate signatures on in cc file
245
- :param args_signature:
246
- :param args_name:
247
- :return:
248
- """
249
- if args_signature is None:
250
- return ''
251
- signature_code = ''
252
- # Init rw.
253
- write_list = []
254
- read_list = []
255
- ref_list = []
256
- if args_signature is not None:
257
- rw_write = args_signature.get('rw_write')
258
- rw_read = args_signature.get('rw_read')
259
- rw_ref = args_signature.get('rw_ref')
260
- if rw_write is not None:
261
- write_list = rw_write.replace(' ', '').split(",")
262
- if rw_read is not None:
263
- read_list = rw_read.replace(' ', '').split(",")
264
- if rw_ref is not None:
265
- ref_list = rw_ref.replace(' ', '').split(",")
266
- # Init dtype group.
267
- same_dtype_groups, _ = get_same_dtype_groups(args_signature, args_name)
268
- for arg_name in args_name:
269
- enum_rw = signature_get_rw_label_cc(arg_name, write_list, read_list, ref_list)
270
- enum_dtype = signature_get_enum_dtype_cc(same_dtype_groups.get(arg_name))
271
- signature = f"""Signature("{arg_name}", {enum_rw}, \
272
- SignatureEnumKind::kKindPositionalKeyword, nullptr, {enum_dtype}),\n """
273
- signature_code += signature
274
- return signature_code
275
-
276
-
277
- def generate_py_op_deprecated(deprecated):
278
- """
279
- Generate @deprecated
280
- """
281
- if deprecated is None:
282
- return ''
283
- version = deprecated.get("version")
284
- if version is None:
285
- raise ValueError("The version of deprecated can't be None.")
286
- substitute = deprecated.get("substitute")
287
- if substitute is None:
288
- raise ValueError("The substitute of deprecated can't be None.")
289
- use_substitute = deprecated.get("use_substitute")
290
- if use_substitute is None:
291
- raise ValueError("The use_substitute of deprecated can't be None.")
292
- if use_substitute is not True and use_substitute is not False:
293
- raise ValueError(f"The use_substitute must be True or False, but got {use_substitute}")
294
-
295
- deprecated = f""" @deprecated("{version}", "{substitute}", {use_substitute})\n"""
296
- return deprecated
297
-
298
-
299
- def _normalize_func_description_fromat(description):
300
- """
301
- Process description.
302
- """
303
- if not description:
304
- return description
305
- lines = description.split("\n")
306
- if len(lines) == 1:
307
- return description
308
- # Add line indentation to other lines after the first line
309
- for i in range(1, len(lines)):
310
- indent = " " if lines[i] else ""
311
- lines[i] = indent + lines[i]
312
- # Remove trailing blank lines
313
- lines = lines if lines[-1] != "" else lines[:-1]
314
- description = "\n".join(lines)
315
- return description
316
-
317
-
318
- def _get_op_description(operator_name, doc_str):
319
- """
320
- Generate ops api description.
321
- """
322
- if doc_str is None:
323
- print(f"Description is None, op_name: {operator_name}")
324
- return ""
325
- description = doc_str.get(operator_name)
326
- if description is None:
327
- print(f"Description is None, op_name: {operator_name}")
328
- return ""
329
- description = description.get("description")
330
- if description is None:
331
- print(f"Description is None, op_name: {operator_name}")
332
- return ""
333
- return _normalize_func_description_fromat(description)
334
-
335
-
336
- def generate_py_op_func(yaml_data, doc_data):
337
- """
338
- Generate operator python function api.
339
- """
340
- gen_py = ''
341
-
342
- for operator_name, operator_data in yaml_data.items():
343
- if _auto_generate_func_disabled(operator_data):
344
- continue
345
- func_name = _get_op_func_name(operator_name, operator_data)
346
- args = operator_data.get('args')
347
- class_name = _get_op_name(operator_name, operator_data)
348
- func_args = []
349
- prim_init_args = []
350
- prim_call_args = []
351
- for arg_name, arg_info in args.items():
352
- is_prim_init = arg_info.get('prim_init')
353
- has_default = 'default' in arg_info.keys()
354
-
355
- # step1: Process function args.
356
- if not has_default:
357
- func_args.append(f"""{arg_name}""")
358
- else:
359
- default_value = arg_info.get('default')
360
- func_args.append(f"""{arg_name}={default_value}""")
361
-
362
- # step2: Process primitive object init args.
363
- if is_prim_init:
364
- prim_init_args.append(arg_name)
365
-
366
- # step3: Process primitive object call args.
367
- else:
368
- prim_call_args.append(arg_name)
369
- description = _get_op_description(operator_name, doc_data)
370
- function_code = f"""\n
371
- def {func_name}({', '.join(arg for arg in func_args)}):
372
- r\"\"\"
373
- {description}
374
- \"\"\"
375
- {operator_name}_op = _get_cache_prim({class_name})({', '.join(arg_name for arg_name in prim_init_args)})
376
- return {operator_name}_op({', '.join(arg_name for arg_name in prim_call_args)})\n"""
377
-
378
- if not prim_init_args:
379
- if _auto_generate_class_disabled(operator_data):
380
- gen_py += f"""\n{operator_name}_op={class_name}()"""
381
- function_code = f"""\n
382
- def {func_name}({', '.join(arg for arg in func_args)}):
383
- r\"\"\"
384
- {description}
385
- \"\"\"
386
- return {operator_name}_op({', '.join(arg_name for arg_name in prim_call_args)})\n"""
387
- else:
388
- dis = operator_data.get("dispatch")
389
- if dis is not None:
390
- enable_pyboost = dis.get("enable")
391
- if enable_pyboost:
392
- function_code = f"""\n
393
- def {func_name}({', '.join(arg for arg in func_args)}):
394
- r\"\"\"
395
- {description}
396
- \"\"\"
397
- return {operator_name}_impl({', '.join(arg_name for arg_name, _ in args.items())})\n"""
398
- gen_py += function_code
399
-
400
- return gen_py
401
-
402
-
403
- def get_dtype(arg_info):
404
- dtype = arg_info.get('dtype')
405
- # Currently, TypeId is represented by int
406
- if dtype == 'TypeId':
407
- dtype = 'int'
408
- return dtype
409
-
410
-
411
- def process_args(class_name, args):
412
- """
413
- Process arg for yaml, get arg_name, init value, type cast, arg_handler, etc.
414
- """
415
- inputs_name = []
416
- args_name = []
417
- args_assign = []
418
- inputs_default = {}
419
- init_args_with_default = []
420
- args_handlers = {}
421
- for arg_name, arg_info in args.items():
422
- dtype = get_dtype(arg_info)
423
- default_value = arg_info.get('default')
424
- has_default = 'default' in arg_info.keys()
425
- is_prim_init = arg_info.get('prim_init')
426
- arg_handler = arg_info.get('arg_handler')
427
-
428
- # step1: get args infos:
429
- if is_prim_init:
430
- # step1.1: get args name:
431
- args_name.append(arg_name)
432
- # step1.2: get args assign with default value:
433
- if has_default:
434
- init_args_with_default.append(f"""{arg_name}={default_value}""")
435
- else:
436
- init_args_with_default.append(f"""{arg_name}""")
437
-
438
- # step1.3: get args set prim arg expression:
439
- assign_str = gen_utils.get_assign_str_by_type_it(class_name, arg_info, arg_name, dtype)
440
- if arg_handler:
441
- assign_str = f""" self._set_prim_arg_with_handler("{arg_name}", {assign_str}, {arg_handler})"""
442
- else:
443
- assign_str = f""" self._set_prim_arg("{arg_name}", {assign_str})"""
444
- args_assign.append(assign_str)
445
- # step2: get inputs infos:
446
- else:
447
- # step2.1: get inputs name:
448
- inputs_name.append(arg_name)
449
-
450
- # step2.2: get default value of inputs:
451
- if has_default:
452
- inputs_default[arg_name] = default_value
453
-
454
- # step2.3: get args_handler functions for inputs
455
- if arg_handler:
456
- args_handlers[arg_name] = arg_handler
457
-
458
- return inputs_name, inputs_default, args_name, args_assign, init_args_with_default, args_handlers
459
-
460
-
461
- def generate_pyboost_import_header(yaml_data):
462
- """
463
- Generate python primitive
464
- """
465
- pyboost_import_header = ''
466
- import_pyboost = CppTemplate("from mindspore._c_expression import $var\n")
467
- for operator_name, operator_data in yaml_data.items():
468
- is_pyboost = is_pyboost_enable(operator_data)
469
- if is_pyboost:
470
- header = import_pyboost.replace(var=get_pyboost_name(operator_name))
471
- pyboost_import_header += header
472
- return pyboost_import_header
473
-
474
-
475
- def _generate_class_description(class_name, func_name, input_args, init_args, func_disabled, doc_str):
476
- """Generate description for every primitive definition."""
477
- if func_disabled:
478
- # if function disabled, function name is equal to operator_name
479
- description = _get_op_description(func_name, doc_str)
480
- description = f""" r\"\"\"
481
- {description}
482
- \"\"\"
483
- """
484
- return description
485
-
486
- # If function is an released API, refer to the function doc.
487
- description_str = f""" r\"\"\"
488
- .. code-block::
489
-
490
- prim = ops.{class_name}({', '.join(init_args)})
491
- out = prim({', '.join(input_args)})
492
-
493
- is equivalent to
494
-
495
- .. code-block::
496
-
497
- ops.{func_name}({", ".join(input_args + init_args)})
498
-
499
- Refer to :func:`mindspore.ops.{func_name}` for more details.
500
- \"\"\"
501
- """
502
- return description_str
503
-
504
-
505
- def get_init_code(init_code, operator_data):
506
- """
507
- Generate init code for primitive
508
- """
509
- labels = operator_data.get('labels')
510
- if labels is not None:
511
- if init_code != "":
512
- init_code += "\n"
513
- init_code += \
514
- '\n'.join([f""" self.add_prim_attr("{key}", {value})""" for key, value in labels.items()])
515
- if init_code == "":
516
- init_code = f""" pass"""
517
- return init_code
518
-
519
-
520
- def generate_py_primitive(yaml_data, doc_str):
521
- """
522
- Generate python primitive
523
- """
524
-
525
- def _generate_arg_handler(class_name, arg, arg_handler, is_optional):
526
- """Generate arg_handler"""
527
- arg_handler_call = f"""{arg_handler}('{class_name}', '{arg}', {arg})"""
528
- if is_optional:
529
- arg_handler_call = f"""{arg} if {arg} is None else {arg_handler_call}"""
530
- return arg_handler_call
531
-
532
- gen_py = ''
533
- for operator_name, operator_data in yaml_data.items():
534
- if _auto_generate_class_disabled(operator_data):
535
- continue
536
- class_name = _get_op_name(operator_name, operator_data)
537
- func_name = _get_op_func_name(operator_name, operator_data)
538
- pyboost_func_name = get_pyboost_name(operator_name)
539
- args = operator_data.get('args')
540
- inputs_args, inputs_default, init_args, args_assign, init_args_with_default, args_handlers = \
541
- process_args(class_name, args)
542
- init_code = '\n'.join(args_assign)
543
- signature_code = generate_py_op_signature(class_name, operator_data.get('args_signature'), inputs_args,
544
- inputs_default)
545
- deprecated_code = generate_py_op_deprecated(operator_data.get('deprecated'))
546
- init_code = get_init_code(init_code, operator_data)
547
- primitive_code = f"""\n
548
- class {class_name}(Primitive):\n"""
549
- func_disabled = _auto_generate_func_disabled(operator_data)
550
- primitive_code += _generate_class_description(class_name, func_name, inputs_args, init_args, func_disabled,
551
- doc_str)
552
- if signature_code != "":
553
- primitive_code += signature_code
554
- if deprecated_code != "":
555
- primitive_code += deprecated_code
556
- primitive_code += f""" @prim_arg_register
557
- def __init__(self"""
558
- if init_args_with_default:
559
- primitive_code += ", " + f"""{', '.join(init_args_with_default) if init_args_with_default else ''}"""
560
- call_args = []
561
- for name in inputs_args:
562
- call_args.append(f"""{name}={inputs_default[name]}""" if name in inputs_default else name)
563
- primitive_code += f"""):
564
- {init_code}
565
-
566
- def __call__(self, {', '.join(call_args)}):"""
567
- is_pyboost = is_pyboost_enable(operator_data)
568
- if is_pyboost:
569
- primitive_code += f"""
570
- return _convert_stub({pyboost_func_name}(self, ["""
571
- else:
572
- primitive_code += f"""
573
- return super().__call__("""
574
- if inputs_args:
575
- args_with_handler = []
576
- for arg in inputs_args:
577
- if arg in args_handlers:
578
- is_optional = inputs_default.get(arg) == "None"
579
- args_with_handler.append(_generate_arg_handler(class_name, arg, args_handlers[arg], is_optional))
580
- else:
581
- args_with_handler.append(arg)
582
- primitive_code += ', '.join(args_with_handler)
583
-
584
- if init_args:
585
- primitive_code += ', '
586
- primitive_code += ', '.join([f'self.{arg}' for arg in init_args])
587
- if is_pyboost:
588
- primitive_code += """]))"""
589
- else:
590
- primitive_code += """)
591
- """
592
-
593
- gen_py += primitive_code
594
- if not init_args:
595
- prim_op_object = f"""\n
596
- {operator_name}_op={class_name}()
597
- """
598
- gen_py += prim_op_object
599
- return gen_py
600
43
 
44
+ import gen_constants as K
601
45
 
602
- def generate_op_name_opdef(yaml_data):
603
- """
604
- Generate op name
605
- """
606
- op_name_head = f"""
607
- #ifndef MINDSPORE_CORE_OP_NAME_H_
608
- #define MINDSPORE_CORE_OP_NAME_H_
609
46
 
610
- namespace mindspore::ops {{
611
- """
47
+ def generate_ops_prim_file(work_path, op_protos, doc_dict, file_pre):
48
+ generator = OpPrimPyGenerator()
49
+ generator.generate(work_path, op_protos, doc_dict, file_pre)
612
50
 
613
- op_name_end = f"""}} // namespace mindspore::ops
614
51
 
615
- #endif // MINDSPORE_CORE_OP_NAME_H_
616
- """
52
+ def generate_ops_def_file(work_path, os_protos, doc_dict, file_pre):
53
+ generator = OpDefPyGenerator()
54
+ generator.generate(work_path, os_protos, doc_dict, file_pre)
617
55
 
618
- op_name_gen = ''
619
- op_name_gen += op_name_head
620
- for operator_name, operator_data in yaml_data.items():
621
- k_name_op = _get_op_name(operator_name, operator_data)
622
- op_name_gen += f"""constexpr auto kName{k_name_op} = "{k_name_op}";
623
- """
624
56
 
625
- op_name_gen += op_name_end
626
- return op_name_gen
627
-
628
-
629
- def generate_op_prim_opdef(yaml_data):
57
+ def generate_ops_py_files(work_path, op_protos, doc_dict, file_pre):
630
58
  """
631
- Generate primitive c++ definition
59
+ Generate ops python file from yaml.
632
60
  """
633
- ops_prim_head = f"""
634
- #ifndef MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_
635
- #define MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_
61
+ generate_ops_prim_file(work_path, op_protos, doc_dict, file_pre)
62
+ generate_ops_def_file(work_path, op_protos, doc_dict, file_pre)
63
+ shutil.copy(os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops_auto_generate_init.txt'),
64
+ os.path.join(work_path, K.PY_AUTO_GEN_PATH, "__init__.py"))
636
65
 
637
- #include <memory>
638
- #include "ir/anf.h"
639
- #include "ir/primitive.h"
640
- #include "{K.MS_OP_DEF_AUTO_GENERATE_PATH}/gen_ops_name.h"
641
- #include "mindapi/base/macros.h"
642
66
 
643
- namespace mindspore::prim {{
644
- """
67
+ def call_ops_def_cc_generator(work_path, op_protos):
68
+ generator = OpsDefCcGenerator()
69
+ generator.generate(work_path, op_protos)
645
70
 
646
- ops_prim_end = f"""}} // namespace mindspore::prim
647
- #endif // MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_
648
- """
649
71
 
650
- ops_prim_gen = ''
651
- ops_prim_gen += ops_prim_head
652
- for operator_name, operator_data in yaml_data.items():
653
- k_name_op = _get_op_name(operator_name, operator_data)
654
- ops_prim_gen += f"""GVAR_DEF(PrimitivePtr, kPrim{k_name_op}, std::make_shared<Primitive>(ops::kName{k_name_op}))
655
- """
656
- ops_prim_gen += ops_prim_end
657
- return ops_prim_gen
72
+ def call_ops_def_h_generator(work_path, op_protos):
73
+ generator = OpsDefHGenerator()
74
+ generator.generate(work_path, op_protos)
658
75
 
659
76
 
660
- def generate_lite_ops(yaml_data):
661
- """
662
- Generate BaseOperator parameter set and get func
663
- """
664
- lite_ops_h_head = f"""
665
- #ifndef MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
666
- #define MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
77
+ def call_ops_primitive_h_generator(work_path, op_protos):
78
+ generator = OpsPrimitiveHGenerator()
79
+ generator.generate(work_path, op_protos)
667
80
 
668
- #include <vector>
669
- #include "ops/base_operator.h"
670
- #include "{K.OP_DEF_AUTO_GENERATE_PATH}/gen_ops_name.h"
671
81
 
672
- namespace mindspore::ops {{
673
- """
82
+ def call_lite_ops_h_generator(work_path, op_protos):
83
+ h_generator = LiteOpsHGenerator()
84
+ h_generator.generate(work_path, op_protos)
674
85
 
675
- lite_ops_h_end = f"""}} // namespace mindspore::ops
676
- #endif // MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
677
- """
678
86
 
679
- lite_ops_cc_head = f"""
680
- #include "{K.OP_DEF_AUTO_GENERATE_PATH}/gen_lite_ops.h"
681
- #include "mindapi/helper.h"
682
- #include "ops/primitive_c.h"
683
- #include "ops/base_operator.h"
684
- #include "abstract/abstract_value.h"
87
+ def call_lite_ops_cc_generator(work_path, op_protos):
88
+ generator = LiteOpsCcGenerator()
89
+ generator.generate(work_path, op_protos)
685
90
 
686
- namespace mindspore::ops {{
687
- """
688
91
 
689
- lite_ops_cc_end = f"""}} // namespace mindspore::ops
690
- """
92
+ def call_ops_name_h_generator(work_path, op_protos):
93
+ h_generator = OpsNameHGenerator()
94
+ h_generator.generate(work_path, op_protos)
691
95
 
692
- lite_ops_h_gen = ''
693
- lite_ops_cc_gen = ''
694
-
695
- lite_ops_h_gen += lite_ops_h_head
696
- lite_ops_cc_gen += lite_ops_cc_head
697
- for operator_name, operator_data in yaml_data.items():
698
- op_name = _get_op_name(operator_name, operator_data)
699
- lite_ops_h_gen += f"""class OPS_API {op_name} : public BaseOperator {{
700
- public:
701
- MIND_API_BASE_MEMBER({op_name});
702
- {op_name}() : BaseOperator(kName{op_name}) {{}}\n"""
703
- args = operator_data.get('args')
704
- for _, (arg_name, arg_info) in enumerate(args.items()):
705
- is_prim_init = arg_info.get('prim_init')
706
- if not is_prim_init:
707
- continue
708
96
 
709
- dtype = get_dtype(arg_info)
710
- if dtype == "str":
711
- dtype = "std::string"
712
- if dtype in ("tuple[str]", "list[str]"):
713
- dtype = "std::vector<std::string>"
714
- if dtype in ("tuple[int]", "list[int]"):
715
- dtype = "std::vector<int64_t>"
716
- if dtype in ("tuple[float]", "list[float]"):
717
- dtype = "std::vector<float>"
718
- if dtype in ("tuple[bool]", "list[bool]"):
719
- dtype = "std::vector<bool>"
720
- if dtype == "int":
721
- dtype = "int64_t"
722
- lite_ops_h_gen += f""" void set_{arg_name}(const {dtype} &{arg_name});\n"""
723
- lite_ops_h_gen += f""" {dtype} get_{arg_name}() const;\n"""
724
-
725
- lite_ops_cc_gen += f"""void {op_name}::set_{arg_name}(const {dtype} &{arg_name}) \
726
- {{ (void)this->AddAttr("{arg_name}", api::MakeValue({arg_name})); }}\n\n"""
727
- lite_ops_cc_gen += f"""{dtype} {op_name}::get_{arg_name}() const \
728
- {{ return GetValue<{dtype}>(GetAttr("{arg_name}")); }}\n\n"""
729
-
730
- op_name = _get_op_name(operator_name, operator_data)
731
- lite_ops_cc_gen += f"""REGISTER_PRIMITIVE_C(kName{op_name}, {op_name});\n"""
732
- lite_ops_cc_gen += f"""MIND_API_OPERATOR_IMPL({op_name}, BaseOperator);\n\n"""
733
- lite_ops_h_gen += f"""}};\n\n"""
734
- lite_ops_h_gen += lite_ops_h_end
735
- lite_ops_cc_gen += lite_ops_cc_end
736
- return lite_ops_h_gen, lite_ops_cc_gen
737
-
738
-
739
- def generate_cc_opdef(yaml_data):
97
+ def generate_ops_cc_files(work_path, op_protos, op_protos_with_deprecated):
740
98
  """
741
- Generate c++ OpDef
99
+ Generate ops c++ file from yaml.
742
100
  """
743
- gen_cc_code = f"""\n
744
- namespace mindspore::ops {{"""
745
- gen_include = f"""\n
746
- #include \"{K.MS_OP_DEF_AUTO_GENERATE_PATH}/gen_ops_def.h\""""
747
- gen_include += f"""
748
- #include \"ir/signature.h\""""
749
-
750
- for operator_name, operator_data in yaml_data.items():
751
- args = operator_data.get('args')
752
- class_name = _get_op_name(operator_name, operator_data)
753
- inputs_args, _, _, _, _, _ = process_args(class_name, args)
754
- signature_code = generate_cc_op_signature(operator_data.get('args_signature'), inputs_args)
755
- args = operator_data.get('args')
756
- returns = operator_data.get('returns')
757
- dispatch = operator_data.get("dispatch")
758
- # dispatch not defined in yaml or dispatch.enable==False
759
- if not dispatch or not dispatch.get("enable"):
760
- dispatch = "false"
761
- else:
762
- dispatch = "true"
763
- enable_dispatch_str = f"""{dispatch}"""
764
-
765
- is_view = operator_data.get('view')
766
- if is_view:
767
- is_view_s = "true"
768
- else:
769
- is_view_s = "false"
770
- is_view_str = f"""{is_view_s}"""
771
-
772
- gen_include += f"""\n#include "{K.MS_OPS_FUNC_IMPL_PATH}/{operator_name}.h\""""
773
- cc_index_str = ''
774
- input_args_str = ''
775
- args_dict = {}
776
- for i, (arg_name, arg_info) in enumerate(args.items()):
777
- args_dict[arg_name] = i
778
- cc_index_str += f"""{{"{arg_name}", {i}}},\n"""
779
- dtype = get_dtype(arg_info)
780
- cc_dtype_str = convert_dtype_str(dtype)
781
-
782
- is_prim_init = 1 if arg_info.get('prim_init') else 0
783
- arg_handler = arg_info.get('arg_handler')
784
- arg_handler_str = "" if arg_handler is None else arg_handler
785
-
786
- type_cast = arg_info.get('type_cast')
787
- type_cast_str = "" if type_cast is None else \
788
- ', '.join('DT_' + type.replace('[', '_').replace(']', '').upper() for type in
789
- (ct.strip() for ct in type_cast.split(",")))
790
-
791
- # default: None is regarded as a optional argument.
792
- is_optional_str = "false"
793
- if 'default' in arg_info.keys() and arg_info.get('default') == "None":
794
- is_optional_str = "true"
795
-
796
- input_args_str += f"""\n {{/*.arg_name_=*/"{arg_name}", /*.arg_dtype_=*/{cc_dtype_str}, """ + \
797
- f"""/*.as_init_arg_=*/{is_prim_init}, /*.arg_handler_=*/"{arg_handler_str}", """ + \
798
- f"""/*.cast_dtype_ =*/{{{type_cast_str}}}, /*.is_optional_=*/{is_optional_str}}},"""
799
-
800
- # Process outputs.
801
- return_args_str = ''
802
- for return_name, return_info in returns.items():
803
- return_dtype = return_info.get('dtype')
804
- ref_name = return_info.get('inplace')
805
- ref_index_str = -1 if ref_name is None else args_dict.get(ref_name)
806
- cc_return_type_str = 'DT_' + return_dtype.replace('[', '_').replace(']', '').upper()
807
- return_args_str += f"""{{/*.arg_name_=*/"{return_name}", /*.arg_dtype_=*/{cc_return_type_str},
808
- /*.inplace_input_index_=*/{ref_index_str}}},\n"""
809
-
810
- op_def_cc = template.OP_PROTO_TEMPLATE.replace(class_name=class_name, input_args=input_args_str,
811
- return_args=return_args_str, signatures=signature_code,
812
- indexes=cc_index_str, enable_dispatch=enable_dispatch_str,
813
- is_view=is_view_str)
814
- gen_cc_code += op_def_cc
815
- if is_view:
816
- view_op_def = op_def_cc.replace(class_name, class_name+"View")
817
- gen_cc_code += view_op_def
818
-
819
- cc_opdef_end = f"""\n}} // namespace mindspore::ops\n"""
820
- return gen_include + gen_cc_code + cc_opdef_end
821
-
822
-
823
- ops_py_prim_header = f"""
824
- \"\"\"Operators definition generated by gen_ops.py, includes primitive classes.\"\"\"
825
-
826
- from mindspore.ops.primitive import Primitive, prim_arg_register
827
- from mindspore.ops import signature as sig
828
- from mindspore.common import dtype as mstype
829
- from mindspore.common._decorator import deprecated
830
- from mindspore.ops._primitive_cache import _get_cache_prim
831
- from mindspore.ops.auto_generate.gen_arg_dtype_cast import type_it
832
- from mindspore.ops.auto_generate.gen_arg_handler import *
833
- from mindspore._c_expression import OpDtype
834
- from mindspore.common._stub_tensor import _convert_stub
835
- """
836
-
837
- ops_py_def_header = f"""
838
- \"\"\"Operators definition generated by gen_ops.py, includes functions.\"\"\"
839
-
840
- from .gen_ops_prim import *
841
- from .pyboost_inner_prim import *
842
- from mindspore.ops.operations.manually_defined.ops_def import *
843
- from mindspore.ops._primitive_cache import _get_cache_prim
844
- """
845
-
846
-
847
- def generate_ops_prim_file(work_path, yaml_str, doc_str, file_pre):
848
- py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/{file_pre}_ops_prim.py')
849
- tmp_py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/tmp_{file_pre}_ops_prim.py')
850
- pyboost_import_header = generate_pyboost_import_header(yaml_str)
851
- py_prim = generate_py_primitive(yaml_str, doc_str)
852
- write_file(tmp_py_path, py_licence_str + ops_py_prim_header + pyboost_import_header + py_prim)
853
- check_change_and_replace_file(py_path, tmp_py_path)
854
-
101
+ call_ops_def_cc_generator(work_path, op_protos_with_deprecated)
102
+ call_ops_def_h_generator(work_path, op_protos_with_deprecated)
103
+ call_ops_primitive_h_generator(work_path, op_protos)
104
+ call_lite_ops_h_generator(work_path, op_protos)
105
+ call_lite_ops_cc_generator(work_path, op_protos)
106
+ call_ops_name_h_generator(work_path, op_protos)
855
107
 
856
- def generate_ops_def_file(work_path, yaml_str, doc_str, file_pre):
857
- py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/{file_pre}_ops_def.py')
858
- tmp_py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/tmp_{file_pre}_ops_def.py')
859
- py_func = generate_py_op_func(yaml_str, doc_str)
860
- write_file(tmp_py_path, py_licence_str + ops_py_def_header + py_func)
861
- check_change_and_replace_file(py_path, tmp_py_path)
862
108
 
863
-
864
- def generate_ops_py_files(work_path, yaml_str, doc_str, file_pre):
109
+ def get_tensor_op_protos_with_deprecated(func_protos, op_protos):
865
110
  """
866
- Generate ops python file from yaml.
111
+ Get op_protos with deprecated op_protos from func_protos.
867
112
  """
868
- generate_ops_prim_file(work_path, yaml_str, doc_str, file_pre)
869
- generate_ops_def_file(work_path, yaml_str, doc_str, file_pre)
870
- shutil.copy(os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops_auto_generate_init.txt'),
871
- os.path.join(work_path, K.PY_AUTO_GEN_PATH, "__init__.py"))
113
+ tensor_op_protos = copy.deepcopy(op_protos)
114
+ for _, item in func_protos.items():
115
+ for func_proto in item:
116
+ op_name = func_proto.op_proto.op_name
117
+ if "deprecated" in func_proto.op_proto.op_name:
118
+ func_proto.op_proto.op_class.name = ''.join(word.capitalize() for word in op_name.split('_'))
119
+ if func_proto.op_proto.op_name[-1] == '_':
120
+ func_proto.op_proto.op_class.name += '_'
121
+ tensor_op_protos.append(func_proto.op_proto)
122
+ return tensor_op_protos
872
123
 
873
124
 
874
- def generate_ops_cc_files(work_path, yaml_str):
875
- """
876
- Generate ops c++ file from yaml.
877
- """
878
- # ops_def
879
- op_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_ops_def.cc')
880
- tmp_op_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_ops_def.cc')
881
- cc_def_code = generate_cc_opdef(yaml_str)
882
- write_file(tmp_op_cc_path, cc_license_str + cc_def_code)
883
- check_change_and_replace_file(op_cc_path, tmp_op_cc_path)
884
-
885
- # ops_primitive
886
- op_prim_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_ops_primitive.h')
887
- tmp_op_prim_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_ops_primitive.h')
888
- op_prim_code = generate_op_prim_opdef(yaml_str)
889
- write_file(tmp_op_prim_path, cc_license_str + op_prim_code)
890
- check_change_and_replace_file(op_prim_path, tmp_op_prim_path)
891
-
892
- # lite_h_ops
893
- lite_ops_h_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_lite_ops.h')
894
- tmp_lite_ops_h_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_lite_ops.h')
895
- lite_ops_h_code, lite_ops_cc_code = generate_lite_ops(yaml_str)
896
- write_file(tmp_lite_ops_h_path, cc_license_str + lite_ops_h_code)
897
- check_change_and_replace_file(lite_ops_h_path, tmp_lite_ops_h_path)
898
-
899
- # lite_cc_ops
900
- lite_ops_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_lite_ops.cc')
901
- tmp_lite_ops_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_lite_ops.cc')
902
- write_file(tmp_lite_ops_cc_path, cc_license_str + lite_ops_cc_code)
903
- check_change_and_replace_file(lite_ops_cc_path, tmp_lite_ops_cc_path)
904
-
905
- # ops_names
906
- op_name_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_ops_name.h')
907
- tmp_op_name_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_ops_name.h')
908
- op_name_code = generate_op_name_opdef(yaml_str)
909
- write_file(tmp_op_name_path, cc_license_str + op_name_code)
910
- check_change_and_replace_file(op_name_path, tmp_op_name_path)
911
-
912
-
913
- def generate_op_labels(yaml_data):
914
- """
915
- Generate python labels
916
- """
917
- gen_label_py = f"""op_labels = {{"""
918
- for operator_name, operator_data in yaml_data.items():
919
- labels = operator_data.get('labels')
920
- if labels is not None:
921
- class_name = _get_op_name(operator_name, operator_data)
922
- gen_label_py += f"""
923
- "{class_name}": {{"""
924
- gen_label_py += f""", """.join([f""""{key}": {value}""" for key, value in labels.items()])
925
- gen_label_py += f"""}},"""
926
- gen_label_py += f"""
927
- }}"""
928
- return gen_label_py
929
-
930
-
931
- def generate_op_arg_default_value(yaml_data):
932
- """
933
- Generate python default value.
934
- """
935
- default_py_header = f"""\"\"\"Operator labels and args default value.\"\"\"
936
- from mindspore.common import dtype as mstype\n\n"""
937
-
938
- gen_default_py = default_py_header + f"""op_args_default_value = {{"""
939
- for operator_name, operator_data in yaml_data.items():
940
- arg_default_dict = {}
941
- args = operator_data.get('args')
942
- for arg_name, arg_info in args.items():
943
- arg_default = arg_info.get('default')
944
- if arg_default is not None:
945
- arg_default_dict[arg_name] = arg_default
946
- if arg_default_dict:
947
- class_name = _get_op_name(operator_name, operator_data)
948
- gen_default_py += f"""
949
- "{class_name}": {{"""
950
- gen_default_py += f""", """.join([f""""{key}": {value}""" for key, value in arg_default_dict.items()])
951
- gen_default_py += f"""}},"""
952
- gen_default_py += f"""
953
- }}"""
954
- return gen_default_py
955
-
956
-
957
- def generate_create_instance_helper_file(work_path, yaml_str):
125
+ def generate_create_instance_helper_file(work_path, op_protos_with_deprecated):
958
126
  """
959
127
  Generate C++ helper file from yaml.
960
128
  """
961
- dst_dir = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
962
- op_py_path = os.path.join(dst_dir, 'cpp_create_prim_instance_helper.py')
963
- tmp_op_py_path = os.path.join(dst_dir, 'tmp_cpp_create_prim_instance_helper.py')
964
- py_labels = generate_op_labels(yaml_str)
965
- py_arg_default = generate_op_arg_default_value(yaml_str)
966
- write_file(tmp_op_py_path, py_licence_str + "\n" + py_arg_default + "\n\n" + py_labels + "\n")
967
- check_change_and_replace_file(op_py_path, tmp_op_py_path)
968
-
969
-
970
- def generate_aclnn_reg_code(yaml_data):
971
- """generate aclnn register code"""
972
- current_path = os.path.dirname(os.path.realpath(__file__))
973
- work_path = os.path.join(current_path, '../../../../')
974
- ops_yaml_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, "ops.yaml")
975
- yaml_str = gen_utils.safe_load_yaml(ops_yaml_path)
976
-
977
- reg_code = f"""
978
- #include "{K.MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn_kernel_mod.h"
979
-
980
- namespace mindspore {{
981
- namespace kernel {{
982
- """
983
- for operator_name, operator_data in yaml_data.items():
984
- dispatch = operator_data.get("dispatch")
985
- if not dispatch or not dispatch.get("enable"):
986
- continue
987
- Ascend = dispatch.get("Ascend")
988
- if Ascend is not None: # KernelMod is provided by yaml, don't auto generate it.
989
- continue
990
- _, _, none_tensor_exist = get_dtypes(operator_data)
991
- if none_tensor_exist:
992
- gen_aclnn_kernel(operator_name, yaml_str, auto=True)
993
- continue
994
- class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
995
- op_class = operator_data.get("class")
996
- if op_class and op_class.get("name") is not None:
997
- class_name = op_class.get("name")
998
- inputs_outputs_num = len(operator_data.get("args")) + len(operator_data.get("returns"))
999
- aclnn_name = AclnnUtils.get_aclnn_interface(class_name)
1000
- reg_code += f"""
1001
- MS_ACLNN_COMMON_KERNEL_FACTORY_REG({class_name}, {aclnn_name}, {inputs_outputs_num});"""
1002
- reg_code += f"""
1003
- }} // namespace kernel
1004
- }} // namespace mindspore
1005
- """
1006
- return reg_code
129
+ generator = CppCreatePrimInstanceHelperGenerator()
130
+ generator.generate(work_path, op_protos_with_deprecated)
1007
131
 
1008
132
 
1009
- def generate_aclnn_reg_file(work_path, yaml_str):
133
+ def generate_aclnn_reg_file(work_path, op_protos):
1010
134
  """
1011
135
  Generate nnacl kernelmod register
1012
136
  """
1013
- tmp_register_file = work_path + f'{K.MS_OPS_KERNEL_PATH}/ascend/opapi/tmp_aclnn_kernel_register.cc'
1014
- register_file = work_path + f'{K.MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn_kernel_register_auto.cc'
1015
- reg_code = generate_aclnn_reg_code(yaml_str)
1016
- write_file(tmp_register_file, cc_license_str + reg_code)
1017
- check_change_and_replace_file(register_file, tmp_register_file)
137
+ generator = AclnnKernelRegisterAutoCcGenerator()
138
+ generator.generate(work_path, op_protos)
1018
139
 
1019
140
 
1020
141
  def generate_arg_handler_files(work_path):
@@ -1037,33 +158,36 @@ def generate_arg_handler_files(work_path):
1037
158
  check_change_and_replace_file(dst_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
1038
159
 
1039
160
 
1040
def gen_tensor_func_code(work_path, op_protos, func_protos, alias_api_mapping):
    """
    Generate the tensor_py function code.

    The work is delegated to ``TensorFuncRegCppGenerator``.

    Args:
        work_path (str): Root of the source tree the generated files are written under.
        op_protos (list): Operator prototypes loaded from the merged ops yaml.
        func_protos: Tensor method prototypes, as returned by ``load_api_protos_from_yaml``.
        alias_api_mapping: Alias mapping returned alongside the API prototypes.
    """
    TensorFuncRegCppGenerator().generate(work_path, op_protos, func_protos, alias_api_mapping)
164
+
165
+
166
def gen_functional_map_code(work_path, tensor_method_protos, mint_func_protos, alias_api_mapping):
    """
    Generate the functional map code.

    The work is delegated to ``FunctionalMapCppGenerator``.

    Args:
        work_path (str): Root of the source tree the generated files are written under.
        tensor_method_protos: Tensor method prototypes loaded from the api yaml.
        mint_func_protos: Mint function prototypes loaded from the api yaml.
        alias_api_mapping: Alias mapping returned alongside the API prototypes.
    """
    FunctionalMapCppGenerator().generate(work_path, tensor_method_protos, mint_func_protos, alias_api_mapping)
169
+
1051
170
 
171
def gen_tensor_docs_code(work_path, tensor_docs_data):
    """
    Generate ``_tensor_docs.py``, which attaches docs to the tensor func APIs
    when mindspore is imported.

    The work is delegated to ``AddTensorDocsGenerator``.

    Args:
        work_path (str): Root of the source tree the generated file is written under.
        tensor_docs_data: Content loaded from the merged tensor method doc yaml.
    """
    AddTensorDocsGenerator().generate(work_path, tensor_docs_data)
1052
174
 
1053
- def main():
175
+
176
def gen_functional_overload_py(work_path, mint_func_protos, function_doc_data, alias_api_mapping):
    """
    Generate ``functional_overload.py``, which initializes the pybind mint APIs from C++.

    The work is delegated to ``FunctionalOverloadPyGenerator``.

    Args:
        work_path (str): Root of the source tree the generated file is written under.
        mint_func_protos: Mint function prototypes loaded from the api yaml.
        function_doc_data: Content loaded from the merged mint function doc yaml.
        alias_api_mapping: Alias mapping returned alongside the API prototypes.
    """
    FunctionalOverloadPyGenerator().generate(work_path, mint_func_protos, function_doc_data, alias_api_mapping)
179
+
180
+
181
+ def main(args):
1054
182
  current_path = os.path.dirname(os.path.realpath(__file__))
1055
183
  work_path = os.path.join(current_path, '../../../../')
1056
184
 
1057
- # merge ops yaml
1058
- ops_yaml_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops.yaml')
1059
- doc_yaml_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops_doc.yaml')
185
+ if args.clear_auto_gen:
186
+ delete_auto_gen_files(work_path)
1060
187
 
1061
- ops_yaml_dir_path = os.path.join(work_path, K.MS_YAML_PATH)
1062
- infer_ops_yaml_dir_path = os.path.join(ops_yaml_dir_path, "infer")
1063
- doc_yaml_dir_path = os.path.join(ops_yaml_dir_path, "doc")
1064
- merge_files(ops_yaml_dir_path, ops_yaml_path, '*op.yaml')
1065
- merge_files_append(infer_ops_yaml_dir_path, ops_yaml_path, '*op.yaml')
1066
- merge_files(doc_yaml_dir_path, doc_yaml_path, '*doc.yaml')
188
+ # merge ops yaml
189
+ (doc_yaml_path, ops_yaml_path, deprecated_ops_yaml_path, ops_api_yaml_path,
190
+ tensor_method_doc_yaml_path, mint_func_doc_yaml_path) = merge_ops_yaml(work_path)
1067
191
 
1068
192
  # make auto_generate dir
1069
193
  cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
@@ -1072,28 +196,159 @@ def main():
1072
196
  # generate arg_handler files
1073
197
  generate_arg_handler_files(work_path)
1074
198
 
1075
- # read ops definition str and doc str
1076
- ops_yaml_str = safe_load_yaml(ops_yaml_path)
1077
- doc_yaml_str = safe_load_yaml(doc_yaml_path)
199
+ # read ops definition str and tensor method doc str
200
+ ops_yaml_dict = safe_load_yaml(ops_yaml_path)
201
+ doc_yaml_dict = safe_load_yaml(doc_yaml_path)
202
+ deprecated_ops_yaml_dict = safe_load_yaml(deprecated_ops_yaml_path)
203
+ ops_api_yaml_dict = safe_load_yaml(ops_api_yaml_path)
204
+ tensor_method_doc_yaml_dict = safe_load_yaml(tensor_method_doc_yaml_path)
205
+ mint_function_doc_yaml_dict = safe_load_yaml(mint_func_doc_yaml_path)
206
+
207
+ op_protos = load_op_protos_from_ops_yaml(ops_yaml_dict)
208
+ deprecated_op_protos = load_deprecated_op_protos_from_ops_yaml(deprecated_ops_yaml_dict)
209
+ tensor_method_protos, mint_func_protos, alias_api_mapping \
210
+ = load_api_protos_from_yaml(ops_api_yaml_dict, op_protos, deprecated_op_protos)
211
+ # for generate tensor method deprecated in graph mode
212
+ op_protos_with_deprecated = get_tensor_op_protos_with_deprecated(tensor_method_protos, op_protos)
1078
213
 
1079
214
  # generate ops python files
1080
- generate_ops_py_files(work_path, ops_yaml_str, doc_yaml_str, "gen")
1081
-
215
+ generate_ops_py_files(work_path, op_protos, doc_yaml_dict, "gen")
1082
216
  # generate ops c++ files
1083
- generate_ops_cc_files(work_path, ops_yaml_str)
217
+ generate_ops_cc_files(work_path, op_protos, op_protos_with_deprecated)
1084
218
  # generate create prim instance helper file
1085
- generate_create_instance_helper_file(work_path, ops_yaml_str)
1086
- # get view extra ops
1087
- extra_ops = get_view_ops(ops_yaml_str)
219
+ generate_create_instance_helper_file(work_path, op_protos_with_deprecated)
1088
220
  # generate pyboost code
1089
- gen_pyboost_code(work_path, ops_yaml_str, doc_yaml_str, extra_ops)
221
+ gen_pyboost_code(work_path, op_protos, doc_yaml_dict, tensor_method_protos, mint_func_protos, alias_api_mapping)
1090
222
  # generate aclnn kernelmod register
1091
- generate_aclnn_reg_file(work_path, ops_yaml_str)
223
+ generate_aclnn_reg_file(work_path, op_protos)
224
+ # generate tensor_py func code
225
+ gen_tensor_func_code(work_path, op_protos, tensor_method_protos, alias_api_mapping)
226
+ # generate functional map code
227
+ gen_functional_map_code(work_path, tensor_method_protos, mint_func_protos, alias_api_mapping)
228
+ # generate _tensor_docs.py that attaches docs to tensor func APIs when import mindspore
229
+ gen_tensor_docs_code(work_path, tensor_method_doc_yaml_dict)
230
+ # generate functional_overload.py which init pybind mint APIs from cpp
231
+ gen_functional_overload_py(work_path, mint_func_protos, mint_function_doc_yaml_dict, alias_api_mapping)
232
+
233
+
234
def _is_within_dir(root, path):
    """Return True when the absolute ``path`` is ``root`` itself or lies underneath it."""
    try:
        return os.path.commonpath([root, path]) == root
    except ValueError:
        # commonpath raises e.g. for paths on different Windows drives: not inside root.
        return False


def delete_auto_gen_files(work_path):
    """
    Delete the auto-generated files and folders listed in the .gitignore file.

    The entries come from ``get_auto_gen_path_from_gitignore(work_path)`` and
    are interpreted as follows:

    * An entry whose only '/' is a trailing one (e.g. ``auto_gen/``) names a
      folder. Every folder with that name under ``work_path`` is removed.
    * Any other entry is a path relative to ``work_path``. A leading '/'
      anchors it to ``work_path``; it never turns it into an absolute path.
    * Negated entries (``!pattern``) are skipped.
    * Paths that resolve outside ``work_path`` are skipped with a warning.

    Args:
        work_path (str): Root of the source tree that contains the .gitignore file.
    """
    auto_gen_code_file = get_auto_gen_path_from_gitignore(work_path)
    abs_work_path = os.path.abspath(work_path)

    folder_names = set()
    relative_paths = []
    for name in auto_gen_code_file:
        if name.startswith('!'):
            # A negation re-includes paths in .gitignore, so it is never a deletion target.
            logging.info("Skipping negated entry: %s", name)
            continue
        stripped = name.rstrip('/')
        if '/' not in stripped:
            folder_names.add(stripped)
        else:
            # os.path.join drops work_path when the second part is absolute,
            # so "/build" must become "build" before joining.
            relative_paths.append(stripped.lstrip('/'))

    # Recursively delete every folder whose name matches a single-level entry.
    # One bottom-up walk serves all such names.
    if folder_names:
        for dir_path, dir_names, _ in os.walk(work_path, topdown=False):
            for dirname in dir_names:
                if dirname in folder_names:
                    folder_path = os.path.join(dir_path, dirname)
                    logging.info("Recursively deleting folder: %s", folder_path)
                    shutil.rmtree(folder_path)

    # Delete all individual files or folders
    for rel_path in relative_paths:
        tmp_path = os.path.join(work_path, rel_path)
        if not _is_within_dir(abs_work_path, os.path.abspath(tmp_path)):
            logging.warning("Skipping path outside the work path: %s", tmp_path)
            continue
        if os.path.exists(tmp_path):
            if os.path.isdir(tmp_path):
                logging.info("Deleting folder: %s", tmp_path)
                shutil.rmtree(tmp_path)
            elif os.path.isfile(tmp_path):
                logging.info("Deleting file: %s", tmp_path)
                os.remove(tmp_path)
        else:
            logging.info("The path is not exist: %s", tmp_path)
262
+
263
+
264
def get_auto_gen_path_from_gitignore(work_path):
    """
    Read the auto-generated file and folder entries listed in the .gitignore file.

    The entries are the non-empty, non-comment lines that directly follow the
    ``# auto gen code files`` marker line in ``<work_path>/.gitignore``. The
    section ends at the first blank line or comment line after the marker.

    Args:
        work_path (str): Directory that contains the .gitignore file.

    Returns:
        list[str]: The whitespace-stripped entries in file order. The list is
        empty when the marker line is absent.

    Raises:
        OSError: If the .gitignore file cannot be opened.
    """
    file_path = os.path.join(work_path, ".gitignore")
    auto_gen_code_file = []
    in_section = False
    # Read as UTF-8 explicitly. The locale default encoding (e.g. GBK on a
    # Chinese-locale Windows host) may fail to decode the file.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            entry = line.strip()
            if not in_section:
                in_section = entry == "# auto gen code files"
                continue
            if not entry or entry.startswith("#"):
                break
            auto_gen_code_file.append(entry)
    return auto_gen_code_file
282
+
283
+
284
def load_op_protos_from_ops_yaml(ops_yaml_data):
    """
    Build one ``OpProto`` for every operator entry of the merged ops yaml.

    Args:
        ops_yaml_data (dict): Mapping from operator name to its yaml definition.

    Returns:
        list: The ``OpProto`` objects, in the mapping's iteration order.
    """
    return [OpProto.load_from_yaml(op_name, op_def) for op_name, op_def in ops_yaml_data.items()]
290
+
291
+
292
def load_deprecated_op_protos_from_ops_yaml(ops_yaml_data):
    """
    Build ``OpProto`` objects for the deprecated operator definitions.

    Each prototype's ``op_name`` is rewritten to ``'deprecated_' + <yaml key>``,
    presumably to keep it distinct from the regular op with the same name.

    Args:
        ops_yaml_data (dict): Mapping from operator name to its yaml definition.

    Returns:
        list: The renamed ``OpProto`` objects, in the mapping's iteration order.
    """
    def _as_deprecated(op_name, op_def):
        proto = OpProto.load_from_yaml(op_name, op_def)
        proto.op_name = 'deprecated_' + op_name
        return proto

    return [_as_deprecated(name, definition) for name, definition in ops_yaml_data.items()]
299
+
300
+
301
def merge_ops_yaml(work_path):
    """
    Merge the operator yaml files scattered over several directories.

    Each group of yaml files is merged into a single file under
    ``K.PY_OPS_GEN_PATH``. The infer op definitions are appended to the
    merged ops definition file.

    Args:
        work_path (str): The path to the working directory.

    Returns:
        tuple: Paths of the merged yaml files, in this order: ops docs,
        ops definitions, deprecated ops, api definitions, tensor method docs,
        mint function docs.
    """
    gen_dir = os.path.join(work_path, K.PY_OPS_GEN_PATH)
    op_def_dir = os.path.join(work_path, K.MS_OP_DEF_YAML_PATH)

    # Op definitions: regular ops first, then the infer ops appended to the same file.
    ops_yaml_path = os.path.join(gen_dir, 'ops.yaml')
    merge_files(op_def_dir, ops_yaml_path, '*op.yaml')
    merge_files_append(os.path.join(op_def_dir, "infer"), ops_yaml_path, '*op.yaml')

    doc_yaml_path = os.path.join(gen_dir, 'ops_doc.yaml')
    merge_files(os.path.join(op_def_dir, "doc"), doc_yaml_path, '*doc.yaml')

    # (source directory relative to work_path, merged file name, glob pattern)
    remaining_groups = (
        (K.MS_OP_API_YAML_PATH, 'api_def.yaml', '*.yaml'),
        (K.MS_OP_DEPRECATED_DEF_YAML_PATH, 'deprecated_ops.yaml', '*_method.yaml'),
        (K.MS_TENSOR_METHOD_DOC_YAML_PATH, 'tensor_method_doc.yaml', '*doc.yaml'),
        (K.MS_MINT_FUNC_DOC_YAML_PATH, 'mint_func_doc.yaml', '*doc.yaml'),
    )
    merged_paths = []
    for src_rel_dir, merged_name, pattern in remaining_groups:
        merged_path = os.path.join(gen_dir, merged_name)
        merge_files(os.path.join(work_path, src_rel_dir), merged_path, pattern)
        merged_paths.append(merged_path)
    api_path, deprecated_path, tensor_doc_path, mint_doc_path = merged_paths

    return (doc_yaml_path, ops_yaml_path, deprecated_path,
            api_path, tensor_doc_path, mint_doc_path)
339
+
340
+
341
def _str_to_bool(value):
    """
    Convert a command-line flag value to a bool.

    Accepts true/false, t/f, yes/no, y/n, on/off and 1/0, case-insensitively.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not one of these spellings.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


def parse_args(argv=None):
    """
    Parse the command-line options of this generator script.

    ``--clear_auto_gen`` given alone means True. An explicit value such as
    ``True``/``False`` is also accepted and parsed as a real boolean.
    Previously any string, including ``"False"``, was truthy and triggered
    the deletion of the auto-generated files.

    Args:
        argv (list[str], optional): Arguments to parse. ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Parsed options with a boolean ``clear_auto_gen``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--clear_auto_gen', type=_str_to_bool, nargs='?', const=True, default=False,
                        help='clear all auto gen files')
    return parser.parse_args(argv)
1092
345
 
1093
346
 
1094
347
if __name__ == "__main__":
    try:
        arguments = parse_args()
        main(arguments)
    # pylint: disable=broad-except
    except Exception as e:
        logging.critical("Auto generate failed, err info: %s", e)
        # A bare raise re-raises the active exception with its original traceback.
        raise