mindspore 2.4.1__cp311-cp311-win_amd64.whl → 2.5.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (395)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +8 -3
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +0 -5
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/compile_config.py +64 -0
  11. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
  13. mindspore/_extends/parse/parser.py +23 -5
  14. mindspore/_extends/parse/standard_method.py +123 -27
  15. mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
  16. mindspore/amp.py +7 -1
  17. mindspore/atlprov.dll +0 -0
  18. mindspore/avcodec-59.dll +0 -0
  19. mindspore/avdevice-59.dll +0 -0
  20. mindspore/avfilter-8.dll +0 -0
  21. mindspore/avformat-59.dll +0 -0
  22. mindspore/avutil-57.dll +0 -0
  23. mindspore/boost/boost_cell_wrapper.py +136 -41
  24. mindspore/c1.dll +0 -0
  25. mindspore/c1xx.dll +0 -0
  26. mindspore/c2.dll +0 -0
  27. mindspore/common/__init__.py +3 -1
  28. mindspore/common/_register_for_tensor.py +0 -1
  29. mindspore/common/_stub_tensor.py +25 -4
  30. mindspore/common/_tensor_cpp_method.py +17 -0
  31. mindspore/common/_tensor_docs.py +6132 -0
  32. mindspore/common/api.py +99 -25
  33. mindspore/common/dtype.py +34 -34
  34. mindspore/common/dump.py +2 -1
  35. mindspore/common/file_system.py +8 -1
  36. mindspore/common/generator.py +2 -0
  37. mindspore/common/hook_handle.py +3 -1
  38. mindspore/common/initializer.py +3 -4
  39. mindspore/common/lazy_inline.py +8 -2
  40. mindspore/common/mindir_util.py +10 -2
  41. mindspore/common/parameter.py +30 -27
  42. mindspore/common/tensor.py +713 -1337
  43. mindspore/communication/__init__.py +1 -1
  44. mindspore/communication/_comm_helper.py +10 -0
  45. mindspore/communication/comm_func.py +215 -173
  46. mindspore/communication/management.py +23 -20
  47. mindspore/context.py +292 -193
  48. mindspore/dataset/__init__.py +23 -19
  49. mindspore/dataset/callback/ds_callback.py +2 -1
  50. mindspore/dataset/core/config.py +84 -3
  51. mindspore/dataset/engine/cache_admin.py +3 -3
  52. mindspore/dataset/engine/cache_client.py +5 -4
  53. mindspore/dataset/engine/datasets.py +192 -149
  54. mindspore/dataset/engine/datasets_audio.py +14 -0
  55. mindspore/dataset/engine/datasets_standard_format.py +28 -11
  56. mindspore/dataset/engine/datasets_text.py +38 -1
  57. mindspore/dataset/engine/datasets_user_defined.py +125 -65
  58. mindspore/dataset/engine/datasets_vision.py +81 -8
  59. mindspore/dataset/engine/iterators.py +281 -63
  60. mindspore/dataset/engine/obs/util.py +8 -0
  61. mindspore/dataset/engine/queue.py +40 -0
  62. mindspore/dataset/engine/samplers.py +26 -2
  63. mindspore/dataset/engine/serializer_deserializer.py +1 -1
  64. mindspore/dataset/engine/validators.py +43 -11
  65. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  66. mindspore/dataset/transforms/transforms.py +29 -12
  67. mindspore/dataset/vision/validators.py +1 -2
  68. mindspore/device_context/__init__.py +21 -0
  69. mindspore/device_context/ascend/__init__.py +25 -0
  70. mindspore/device_context/ascend/device.py +72 -0
  71. mindspore/device_context/ascend/op_debug.py +94 -0
  72. mindspore/device_context/ascend/op_precision.py +193 -0
  73. mindspore/device_context/ascend/op_tuning.py +127 -0
  74. mindspore/device_context/cpu/__init__.py +25 -0
  75. mindspore/device_context/cpu/device.py +62 -0
  76. mindspore/device_context/cpu/op_tuning.py +43 -0
  77. mindspore/device_context/gpu/__init__.py +21 -0
  78. mindspore/device_context/gpu/device.py +70 -0
  79. mindspore/device_context/gpu/op_precision.py +67 -0
  80. mindspore/device_context/gpu/op_tuning.py +175 -0
  81. mindspore/device_manager.py +134 -0
  82. mindspore/dnnl.dll +0 -0
  83. mindspore/dpcmi.dll +0 -0
  84. mindspore/experimental/llm_boost/__init__.py +3 -2
  85. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  86. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  87. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  88. mindspore/experimental/llm_boost/atb/boost_base.py +239 -64
  89. mindspore/experimental/llm_boost/atb/llama_boost.py +52 -30
  90. mindspore/experimental/llm_boost/atb/qwen_boost.py +47 -24
  91. mindspore/experimental/llm_boost/register.py +1 -0
  92. mindspore/experimental/optim/adadelta.py +26 -22
  93. mindspore/experimental/optim/adam.py +3 -0
  94. mindspore/experimental/optim/lr_scheduler.py +33 -24
  95. mindspore/experimental/optim/radam.py +33 -30
  96. mindspore/hal/device.py +28 -0
  97. mindspore/hal/event.py +17 -0
  98. mindspore/hal/memory.py +94 -3
  99. mindspore/hal/stream.py +91 -6
  100. mindspore/include/api/context.h +1 -2
  101. mindspore/include/dataset/constants.h +2 -2
  102. mindspore/jpeg62.dll +0 -0
  103. mindspore/log.py +12 -0
  104. mindspore/mindrecord/__init__.py +1 -1
  105. mindspore/mindrecord/config.py +17 -316
  106. mindspore/mindrecord/filereader.py +1 -9
  107. mindspore/mindrecord/filewriter.py +5 -15
  108. mindspore/mindrecord/mindpage.py +1 -9
  109. mindspore/mindspore_backend.dll +0 -0
  110. mindspore/mindspore_common.dll +0 -0
  111. mindspore/mindspore_core.dll +0 -0
  112. mindspore/mindspore_glog.dll +0 -0
  113. mindspore/mindspore_ops.dll +0 -0
  114. mindspore/mint/__init__.py +824 -218
  115. mindspore/mint/distributed/__init__.py +66 -4
  116. mindspore/mint/distributed/distributed.py +2594 -44
  117. mindspore/mint/linalg/__init__.py +6 -0
  118. mindspore/mint/nn/__init__.py +473 -14
  119. mindspore/mint/nn/functional.py +486 -11
  120. mindspore/mint/nn/layer/__init__.py +17 -4
  121. mindspore/mint/nn/layer/_functions.py +330 -0
  122. mindspore/mint/nn/layer/activation.py +169 -1
  123. mindspore/mint/nn/layer/basic.py +123 -0
  124. mindspore/mint/nn/layer/conv.py +727 -0
  125. mindspore/mint/nn/layer/normalization.py +215 -19
  126. mindspore/mint/nn/layer/padding.py +797 -0
  127. mindspore/mint/nn/layer/pooling.py +170 -0
  128. mindspore/mint/optim/__init__.py +2 -1
  129. mindspore/mint/optim/adam.py +223 -0
  130. mindspore/mint/optim/adamw.py +26 -19
  131. mindspore/mint/special/__init__.py +2 -1
  132. mindspore/msobj140.dll +0 -0
  133. mindspore/mspdb140.dll +0 -0
  134. mindspore/mspdbcore.dll +0 -0
  135. mindspore/mspdbst.dll +0 -0
  136. mindspore/mspft140.dll +0 -0
  137. mindspore/msvcdis140.dll +0 -0
  138. mindspore/msvcp140_1.dll +0 -0
  139. mindspore/msvcp140_2.dll +0 -0
  140. mindspore/msvcp140_atomic_wait.dll +0 -0
  141. mindspore/msvcp140_codecvt_ids.dll +0 -0
  142. mindspore/multiprocessing/__init__.py +5 -0
  143. mindspore/nn/__init__.py +2 -0
  144. mindspore/nn/cell.py +142 -21
  145. mindspore/nn/dynamic_lr.py +2 -1
  146. mindspore/nn/layer/activation.py +6 -6
  147. mindspore/nn/layer/basic.py +35 -25
  148. mindspore/nn/layer/channel_shuffle.py +3 -3
  149. mindspore/nn/layer/conv.py +3 -0
  150. mindspore/nn/layer/embedding.py +3 -3
  151. mindspore/nn/layer/normalization.py +8 -7
  152. mindspore/nn/layer/padding.py +4 -3
  153. mindspore/nn/layer/pooling.py +55 -23
  154. mindspore/nn/layer/rnn_cells.py +1 -1
  155. mindspore/nn/layer/rnns.py +2 -1
  156. mindspore/nn/layer/timedistributed.py +5 -5
  157. mindspore/nn/layer/transformer.py +48 -26
  158. mindspore/nn/learning_rate_schedule.py +5 -3
  159. mindspore/nn/loss/loss.py +31 -36
  160. mindspore/nn/optim/ada_grad.py +1 -0
  161. mindspore/nn/optim/adadelta.py +2 -2
  162. mindspore/nn/optim/adam.py +1 -1
  163. mindspore/nn/optim/lars.py +1 -4
  164. mindspore/nn/optim/optimizer.py +1 -1
  165. mindspore/nn/optim/rprop.py +2 -2
  166. mindspore/nn/optim/thor.py +2 -1
  167. mindspore/nn/utils/__init__.py +22 -0
  168. mindspore/nn/utils/init.py +73 -0
  169. mindspore/nn/wrap/cell_wrapper.py +4 -6
  170. mindspore/nn/wrap/loss_scale.py +3 -4
  171. mindspore/numpy/array_creations.py +60 -62
  172. mindspore/numpy/array_ops.py +148 -143
  173. mindspore/numpy/logic_ops.py +41 -42
  174. mindspore/numpy/math_ops.py +361 -359
  175. mindspore/numpy/utils.py +16 -16
  176. mindspore/numpy/utils_const.py +4 -4
  177. mindspore/opencv_core452.dll +0 -0
  178. mindspore/opencv_imgcodecs452.dll +0 -0
  179. mindspore/opencv_imgproc452.dll +0 -0
  180. mindspore/ops/__init__.py +2 -1
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +107 -8
  182. mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
  183. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  184. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  185. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  186. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  187. mindspore/ops/_vmap/vmap_array_ops.py +20 -19
  188. mindspore/ops/_vmap/vmap_base.py +0 -2
  189. mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
  190. mindspore/ops/_vmap/vmap_math_ops.py +11 -9
  191. mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
  192. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
  193. mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
  194. mindspore/ops/auto_generate/gen_extend_func.py +554 -60
  195. mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
  196. mindspore/ops/auto_generate/gen_ops_prim.py +8027 -3411
  197. mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
  198. mindspore/ops/composite/base.py +1 -1
  199. mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
  200. mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
  201. mindspore/ops/function/__init__.py +12 -0
  202. mindspore/ops/function/array_func.py +561 -159
  203. mindspore/ops/function/clip_func.py +64 -0
  204. mindspore/ops/function/debug_func.py +28 -20
  205. mindspore/ops/function/image_func.py +1 -1
  206. mindspore/ops/function/linalg_func.py +5 -4
  207. mindspore/ops/function/math_func.py +1664 -294
  208. mindspore/ops/function/nn_func.py +988 -317
  209. mindspore/ops/function/parameter_func.py +3 -56
  210. mindspore/ops/function/random_func.py +243 -33
  211. mindspore/ops/function/sparse_unary_func.py +1 -1
  212. mindspore/ops/functional.py +18 -5
  213. mindspore/ops/functional_overload.py +897 -0
  214. mindspore/ops/operations/__init__.py +3 -2
  215. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  216. mindspore/ops/operations/_grad_ops.py +2 -34
  217. mindspore/ops/operations/_infer_ops.py +2 -1
  218. mindspore/ops/operations/_inner_ops.py +38 -8
  219. mindspore/ops/operations/array_ops.py +45 -303
  220. mindspore/ops/operations/comm_ops.py +23 -17
  221. mindspore/ops/operations/custom_ops.py +7 -49
  222. mindspore/ops/operations/debug_ops.py +42 -47
  223. mindspore/ops/operations/inner_ops.py +6 -4
  224. mindspore/ops/operations/linalg_ops.py +3 -2
  225. mindspore/ops/operations/manually_defined/ops_def.py +185 -104
  226. mindspore/ops/operations/math_ops.py +11 -216
  227. mindspore/ops/operations/nn_ops.py +153 -310
  228. mindspore/ops/primitive.py +23 -21
  229. mindspore/ops/tensor_method.py +1669 -0
  230. mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
  231. mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
  232. mindspore/ops_generate/arg_handler.py +0 -61
  233. mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
  234. mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
  235. mindspore/ops_generate/base_generator.py +11 -0
  236. mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
  237. mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
  238. mindspore/ops_generate/functional_overload_py_generator.py +110 -0
  239. mindspore/ops_generate/functions_cc_generator.py +233 -0
  240. mindspore/ops_generate/gen_aclnn_implement.py +110 -114
  241. mindspore/ops_generate/gen_constants.py +157 -3
  242. mindspore/ops_generate/gen_ops.py +245 -990
  243. mindspore/ops_generate/gen_pyboost_func.py +97 -998
  244. mindspore/ops_generate/gen_utils.py +119 -33
  245. mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
  246. mindspore/ops_generate/op_api_proto.py +206 -0
  247. mindspore/ops_generate/op_def_py_generator.py +131 -0
  248. mindspore/ops_generate/op_prim_py_generator.py +480 -0
  249. mindspore/ops_generate/op_proto.py +373 -108
  250. mindspore/ops_generate/op_template_parser.py +436 -0
  251. mindspore/ops_generate/ops_def_cc_generator.py +288 -0
  252. mindspore/ops_generate/ops_def_h_generator.py +74 -0
  253. mindspore/ops_generate/ops_name_h_generator.py +68 -0
  254. mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
  255. mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
  256. mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
  257. mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
  258. mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
  259. mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
  260. mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
  261. mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
  262. mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
  263. mindspore/ops_generate/pyboost_utils.py +92 -33
  264. mindspore/ops_generate/template.py +294 -44
  265. mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
  266. mindspore/parallel/__init__.py +3 -3
  267. mindspore/parallel/_auto_parallel_context.py +44 -34
  268. mindspore/parallel/_cell_wrapper.py +22 -3
  269. mindspore/parallel/_parallel_serialization.py +13 -2
  270. mindspore/parallel/_utils.py +4 -2
  271. mindspore/parallel/algo_parameter_config.py +1 -1
  272. mindspore/parallel/checkpoint_transform.py +44 -0
  273. mindspore/parallel/cluster/process_entity/_api.py +131 -37
  274. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  275. mindspore/parallel/cluster/run.py +20 -3
  276. mindspore/parallel/parameter_broadcast.py +1 -1
  277. mindspore/parallel/shard.py +3 -0
  278. mindspore/parallel/transform_safetensors.py +119 -253
  279. mindspore/pgodb140.dll +0 -0
  280. mindspore/pgort140.dll +0 -0
  281. mindspore/profiler/__init__.py +17 -4
  282. mindspore/profiler/analysis/__init__.py +0 -0
  283. mindspore/profiler/analysis/parser/__init__.py +0 -0
  284. mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
  285. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  286. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  287. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  288. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  289. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  290. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
  291. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  292. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
  293. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  294. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  295. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  296. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  297. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  298. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  299. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  300. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  301. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  302. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  303. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
  304. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  305. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  306. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  307. mindspore/profiler/analysis/task_manager.py +131 -0
  308. mindspore/profiler/analysis/time_converter.py +84 -0
  309. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  310. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
  311. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  312. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
  313. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
  314. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
  315. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
  316. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  317. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  318. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
  319. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  320. mindspore/profiler/analysis/work_flow.py +73 -0
  321. mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
  322. mindspore/profiler/common/command_executor.py +90 -0
  323. mindspore/profiler/common/constant.py +174 -3
  324. mindspore/profiler/common/file_manager.py +208 -0
  325. mindspore/profiler/common/log.py +130 -0
  326. mindspore/profiler/common/msprof_cmd_tool.py +202 -0
  327. mindspore/profiler/common/path_manager.py +371 -0
  328. mindspore/profiler/common/process_bar.py +168 -0
  329. mindspore/profiler/common/process_pool.py +9 -3
  330. mindspore/profiler/common/profiler_context.py +476 -0
  331. mindspore/profiler/common/profiler_info.py +304 -0
  332. mindspore/profiler/common/profiler_output_path.py +284 -0
  333. mindspore/profiler/common/profiler_parameters.py +210 -0
  334. mindspore/profiler/common/profiler_path_manager.py +120 -0
  335. mindspore/profiler/common/record_function.py +76 -0
  336. mindspore/profiler/common/tlv_decoder.py +76 -0
  337. mindspore/profiler/common/util.py +75 -2
  338. mindspore/profiler/dynamic_profiler.py +270 -37
  339. mindspore/profiler/envprofiler.py +138 -0
  340. mindspore/profiler/mstx.py +199 -0
  341. mindspore/profiler/platform/__init__.py +21 -0
  342. mindspore/profiler/platform/base_profiler.py +40 -0
  343. mindspore/profiler/platform/cpu_profiler.py +124 -0
  344. mindspore/profiler/platform/gpu_profiler.py +74 -0
  345. mindspore/profiler/platform/npu_profiler.py +309 -0
  346. mindspore/profiler/profiler.py +580 -93
  347. mindspore/profiler/profiler_action_controller.py +187 -0
  348. mindspore/profiler/profiler_interface.py +114 -0
  349. mindspore/profiler/schedule.py +208 -0
  350. mindspore/rewrite/api/symbol_tree.py +1 -2
  351. mindspore/run_check/_check_version.py +18 -13
  352. mindspore/runtime/__init__.py +37 -0
  353. mindspore/runtime/device.py +27 -0
  354. mindspore/runtime/event.py +209 -0
  355. mindspore/runtime/executor.py +148 -0
  356. mindspore/runtime/memory.py +392 -0
  357. mindspore/runtime/stream.py +460 -0
  358. mindspore/runtime/thread_bind_core.py +401 -0
  359. mindspore/swresample-4.dll +0 -0
  360. mindspore/swscale-6.dll +0 -0
  361. mindspore/tbbmalloc.dll +0 -0
  362. mindspore/tinyxml2.dll +0 -0
  363. mindspore/train/__init__.py +2 -2
  364. mindspore/train/_utils.py +53 -18
  365. mindspore/train/amp.py +8 -4
  366. mindspore/train/callback/_checkpoint.py +32 -18
  367. mindspore/train/callback/_early_stop.py +1 -1
  368. mindspore/train/callback/_flops_collector.py +105 -69
  369. mindspore/train/callback/_history.py +1 -1
  370. mindspore/train/callback/_summary_collector.py +44 -6
  371. mindspore/train/callback/_tft_register.py +37 -15
  372. mindspore/train/dataset_helper.py +11 -11
  373. mindspore/train/metrics/precision.py +4 -5
  374. mindspore/train/mind_ir_pb2.py +167 -46
  375. mindspore/train/model.py +13 -14
  376. mindspore/train/serialization.py +461 -72
  377. mindspore/train/summary/summary_record.py +1 -2
  378. mindspore/train/train_thor/model_thor.py +1 -1
  379. mindspore/turbojpeg.dll +0 -0
  380. mindspore/utils/__init__.py +4 -2
  381. mindspore/utils/dryrun.py +138 -0
  382. mindspore/utils/runtime_execution_order_check.py +550 -0
  383. mindspore/vcmeta.dll +0 -0
  384. mindspore/vcruntime140.dll +0 -0
  385. mindspore/vcruntime140_1.dll +0 -0
  386. mindspore/version.py +1 -1
  387. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/METADATA +3 -4
  388. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/RECORD +391 -265
  389. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
  390. mindspore/common/_tensor_overload.py +0 -139
  391. mindspore/mindspore_np_dtype.dll +0 -0
  392. mindspore/profiler/envprofiling.py +0 -254
  393. mindspore/profiler/profiling.py +0 -1926
  394. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
  395. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,370 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ This module defines the PyboostFunctionsGenerator class for generating C++ functions for PyBoost operations.
17
+
18
+ The generator processes operator prototypes and constructs the necessary function definitions, including
19
+ conversions for optional parameters and tensor arguments. It generates the registration code and includes
20
+ the necessary header files for the generated functions.
21
+ """
22
+
23
+ import os
24
+
25
+ import pyboost_utils
26
+ from pyboost_utils import get_convert_type_str, is_optional_param, is_op_multi_output
27
+ import template
28
+ from template import Template
29
+ import gen_constants as K
30
+ from gen_utils import save_file
31
+ from op_proto import OpProto
32
+ from op_template_parser import OpTemplateParser
33
+ from base_generator import BaseGenerator
34
+
35
+
36
+ class PyboostFunctionsGenerator(BaseGenerator):
37
+ """
38
+ Generates PyBoost functions based on operator prototypes.
39
+
40
+ This class processes operator prototypes (`op_protos`) to create the necessary C++ function definitions for
41
+ PyBoost operations. It constructs function bodies, handles optional value conversions, and generates
42
+ registration code and header inclusions.
43
+ """
44
+
45
def __init__(self):
    """Sets up every code template this generator uses to emit the PyBoost C++ functions file."""
    # Templates that are owned by the shared ``template`` module and simply aliased here.
    self.PYBOOST_FUNCTION_TEMPLATE = template.PYBOOST_FUNCTION_TEMPLATE
    self.PYBOOST_COMM_FUNCTION_TEMPLATE = template.PYBOOST_COMM_FUNCTION_TEMPLATE
    self.REGISTER_DEFINE_TEMPLATE = template.REGISTER_DEFINE_TEMPLATE
    self.REGISTER_TEMPLATE = template.REGISTER_TEMPLATE
    self.PYBOOST_HEADER_TEMPLATE = template.PYBOOST_FUNCTIONS_CC_TEMPLATE
    self.TENSOR_FUNC_CLASS_REG = template.TENSOR_FUNC_CLASS_REG

    # Per-operator snippet templates built locally.
    # ``${{operator_name}}`` in the f-string renders as the literal placeholder ``${operator_name}``.
    include_line = f'#include "{K.MS_COMMON_PYBOOST_KERNEL_PATH}/auto_generate/${{operator_name}}.h"\n'
    self.pyboost_func_include_header_template = Template(include_line)

    optional_to_value_line = "auto ${output} = PyNativeAlgo::PyBoost::OptionalToValue(${input});\n"
    self.convert_optional_to_value_template = Template(optional_to_value_line)

    # Both stub-conversion calls share the same trailing arguments.
    stub_call_tail = ', ${need_contiguous}, op_run_info->requires_grad);\n'
    self.convert_to_tensor_template = Template(
        'auto ${output} = PyNativeAlgo::Common::ConvertStubNodeToTensor(${input}' + stub_call_tail
    )
    self.convert_to_tensor_list_template = Template(
        'auto ${output} = PyNativeAlgo::Common::ConvertStubNodeToValueTuple(${input}' + stub_call_tail
    )

    converter_line = "auto $arg_name = converter.${convert_func}(args, $arg_index);\n"
    self.convert_template = Template(converter_line)
+
69
def generate(self, work_path, op_protos, tensor_func_protos_data):
    """
    Generates the C++ PyBoost functions and writes them to ``pyboost_functions.cc``.

    Operator prototypes that have no dispatch entry, or whose dispatch is disabled, are skipped.
    Every other prototype contributes three pieces:

    * a function definition, followed by two ``template.NEW_LINE`` separators. It is built from
      ``PYBOOST_COMM_FUNCTION_TEMPLATE`` for communication ops and from ``PYBOOST_FUNCTION_TEMPLATE``
      for all other ops.
    * a registration entry, built from ``REGISTER_DEFINE_TEMPLATE``.
    * an ``#include`` line for the operator's generated kernel header.

    These pieces and the tensor function class registrations are substituted into
    ``PYBOOST_HEADER_TEMPLATE``. The result is saved as ``pyboost_functions.cc`` under
    ``os.path.join(work_path, K.PIPELINE_PYBOOST_FUNC_GEN_PATH)``.

    Args:
        work_path (str): Root directory under which the generated file is saved.
        op_protos (list): Operator prototypes to process.
        tensor_func_protos_data (dict): Tensor prototype data. This method does not read it; the
            parameter is kept so the call signature stays unchanged.

    Returns:
        None
    """
    # Collect pieces in lists and join once, instead of repeated string concatenation in the loop.
    func_defs = []
    register_defs = []
    include_headers = []
    for op_proto in op_protos:
        if op_proto.op_dispatch is None or not op_proto.op_dispatch.enable:
            continue
        op_parser = OpTemplateParser(op_proto)
        # Used both for the function definition and for its registration entry below.
        op_pyboost_func_name = op_parser.get_pyboost_func_name()
        op_def_name_str = op_parser.get_op_def_name_str()
        type_num, same_type = op_parser.gen_signature_same_type_table()
        parser_body_str = self._generate_parser_func(op_proto)
        op_args_str = [op_arg.arg_name for op_arg in op_proto.op_args]
        convert_stub_str = self._get_convert_stub_str(op_proto)
        optional_to_value_str = self._get_optional_to_value_str(op_proto)
        call_args_str = self._get_call_args_str(op_proto)
        grad_args_str = self._get_grad_args_str(op_proto)
        cast_args_str = self._get_cast_to_value_str(op_proto)
        view_arg_str = self._get_first_str(op_proto.op_view, grad_args_str)
        # Prepend the argument separator only when there is a view argument to pass.
        view_arg_str = ", " + view_arg_str if view_arg_str else ''
        multi_output_str = 'Multi' if is_op_multi_output(op_proto.op_returns) else ''
        # communication operators have different func template
        function_tpl = self.PYBOOST_COMM_FUNCTION_TEMPLATE \
            if op_proto.op_dispatch.is_comm_op else self.PYBOOST_FUNCTION_TEMPLATE
        func_def = function_tpl.replace(func_name=op_pyboost_func_name,
                                        op_def_name=op_def_name_str,
                                        type_num=type_num,
                                        same_type=same_type,
                                        parser_body=parser_body_str,
                                        op_name=op_proto.op_class.name,
                                        class_name=op_proto.op_class.name,
                                        op_args=op_args_str,
                                        convert_stub=convert_stub_str,
                                        optional_to_value=optional_to_value_str,
                                        call_args=call_args_str,
                                        grad_args=grad_args_str,
                                        cast_args=cast_args_str,
                                        view_arg=view_arg_str,
                                        is_multi=multi_output_str,
                                        operator_name=op_proto.op_name)
        func_defs.append(func_def + template.NEW_LINE + template.NEW_LINE)
        register_defs.append(self.REGISTER_DEFINE_TEMPLATE.replace(
            pyboost_op_name=op_parser.get_pyboost_name(),
            pyboost_cfunc_name=op_pyboost_func_name,
            class_name=op_proto.op_class.name))
        include_headers.append(self.pyboost_func_include_header_template.replace(
            operator_name=op_proto.op_name))
    register_func_str = self.REGISTER_TEMPLATE.replace(register_func=''.join(register_defs))
    function_class_register = self._get_function_class_register(op_protos)
    pyboost_func_file = self.PYBOOST_HEADER_TEMPLATE.replace(include_op_header=''.join(include_headers),
                                                             function_body=''.join(func_defs),
                                                             register_function_body=register_func_str,
                                                             function_class_register=function_class_register)
    save_path = os.path.join(work_path, K.PIPELINE_PYBOOST_FUNC_GEN_PATH)
    file_name = "pyboost_functions.cc"
    save_file(save_path, file_name, pyboost_func_file)
+
143
+
144
def _get_cast_args_with_type_str(self, op_proto, cast_args_str):
    """
    Builds C++ parameter declarations for an operator's cast arguments.

    Each operator argument is paired positionally with the matching name in ``cast_args_str``.
    The pair is rendered as ``const <dtype> &<name>``, where ``<dtype>`` is
    ``pyboost_utils.get_input_dtype`` applied to the argument's dtype and optionality. As with
    ``zip``, pairing stops at the shorter of the two sequences.

    Args:
        op_proto (OpProto): Operator prototype whose ``op_args`` supply the dtypes.
        cast_args_str (list[str]): Argument names to declare, in the same order as ``op_proto.op_args``.

    Returns:
        list[str]: One declaration string per paired argument.
    """
    # ``get_input_dtype`` is not among the names imported from ``pyboost_utils`` at the top of this
    # module. A bare reference raises NameError, so resolve it through the imported module instead.
    return [
        "const " + pyboost_utils.get_input_dtype(op_arg.arg_dtype, is_optional_param(op_arg))
        + " &" + cast_args_name
        for op_arg, cast_args_name in zip(op_proto.op_args, cast_args_str)
    ]
+
151
+
152
def _get_function_class_register(self, op_protos) -> str:
    """
    Builds the tensor function class registration code.

    Each operator prototype with a present and enabled dispatch contributes one
    ``TENSOR_FUNC_CLASS_REG`` entry, filled with the prototype's class name and operator name.
    Entries are concatenated in input order.

    Args:
        op_protos (list): Operator prototypes to register.

    Returns:
        str: The concatenated registration entries. It is empty when no prototype qualifies.
    """
    dispatched = (
        proto for proto in op_protos
        if proto.op_dispatch is not None and proto.op_dispatch.enable
    )
    return ''.join(
        self.TENSOR_FUNC_CLASS_REG.replace(class_name=proto.op_class.name, op_name=proto.op_name)
        for proto in dispatched
    )
171
+
172
def _generate_parser_func(self, op_proto: OpProto) -> str:
    """
    Emits one converter call per operator argument.

    For each argument, in declaration order, ``convert_template`` is filled with three values:
    - the argument name;
    - the converter function chosen by ``get_convert_type_str``;
    - the argument's position, formatted by ``pyboost_utils.get_index``.

    Args:
        op_proto (OpProto): Operator prototype whose ``op_args`` are converted.

    Returns:
        str: The converter calls concatenated in argument order.
    """
    snippets = []
    for position, arg in enumerate(op_proto.op_args):
        optional = is_optional_param(arg)
        # Type-id arguments always use the 'type' converter, whatever their declared dtype.
        dtype_key = 'type' if arg.is_type_id else arg.arg_dtype
        converter_name = get_convert_type_str(dtype_key, optional)
        snippets.append(self.convert_template.replace(arg_name=arg.arg_name,
                                                      convert_func=converter_name,
                                                      arg_index=pyboost_utils.get_index(position)))
    return ''.join(snippets)
195
+
196
def _get_convert_stub_str(self, op_proto: OpProto):
    """
    Emits the stub-to-tensor conversion code for tensor and tensor-list arguments.

    Arguments are handled as follows:
    - Tensor arguments use ``convert_to_tensor_template``.
    - Tensor-list arguments use ``convert_to_tensor_list_template``.
    - All other arguments produce no code.

    The output variable is named ``<arg>_optional`` for optional arguments. Otherwise it is
    ``<arg>_tensor`` for tensors and ``<arg>_tensor_list`` for tensor lists. The emitted
    ``need_contiguous`` flag is ``false`` for view ops and ``true`` otherwise.

    Args:
        op_proto (OpProto): Operator prototype whose arguments are converted.

    Returns:
        str: The conversion code concatenated in argument order.
    """
    # View/ACLNN op does not need to convert to contiguous tensor.
    need_contiguous = 'false' if op_proto.op_view else 'true'
    chunks = []
    for arg in op_proto.op_args:
        if pyboost_utils.is_tensor(arg):
            stub_template, default_suffix = self.convert_to_tensor_template, "_tensor"
        elif pyboost_utils.is_tensor_list(arg):
            # To adapt the cases where TensorList is optional.
            stub_template, default_suffix = self.convert_to_tensor_list_template, "_tensor_list"
        else:
            continue
        suffix = '_optional' if is_optional_param(arg) else default_suffix
        chunks.append(stub_template.replace(input=arg.arg_name,
                                            output=arg.arg_name + suffix,
                                            need_contiguous=need_contiguous))
    return ''.join(chunks)
229
+
230
def _get_optional_to_value_str(self, op_proto: OpProto):
    """
    Build the code that unwraps optional arguments into ``<name>_value`` variables.

    Only optional arguments (per ``is_optional_param``) produce code. Each one is rendered
    through ``self.convert_optional_to_value_template``.
    - Optional tensors and tensor lists read from ``cast_<name>_optional``.
    - Every other optional argument reads from ``<name>`` directly.

    Args:
        op_proto (OpProto): The operator prototype containing the argument information.

    Returns:
        str: The concatenated unwrapping code.
    """
    pieces = []
    for arg in op_proto.op_args:
        if not is_optional_param(arg):
            continue
        tensor_like = pyboost_utils.is_tensor(arg) or pyboost_utils.is_tensor_list(arg)
        source_name = 'cast_' + arg.arg_name + '_optional' if tensor_like else arg.arg_name
        pieces.append(
            self.convert_optional_to_value_template.replace(input=source_name,
                                                            output=arg.arg_name + '_value')
        )
    return ''.join(pieces)
260
+
261
def _get_call_args_str(self, op_proto: OpProto):
    """
    List the variable names passed as call arguments for the operator.

    Naming rules:
    - Tensors and tensor lists use ``<name>_optional`` when optional; otherwise
      ``<name>_tensor`` or ``<name>_tensor_list`` respectively.
    - Every other argument is passed under its own ``<name>``.

    Args:
        op_proto (OpProto): The operator prototype containing the argument information.

    Returns:
        list: Argument variable names, in declaration order.
    """
    def _call_name(arg):
        if pyboost_utils.is_tensor(arg):
            required_suffix = "_tensor"
        elif pyboost_utils.is_tensor_list(arg):
            required_suffix = "_tensor_list"
        else:
            return arg.arg_name
        return arg.arg_name + ('_optional' if is_optional_param(arg) else required_suffix)

    return [_call_name(arg) for arg in op_proto.op_args]
288
+
289
def _get_grad_args_str(self, op_proto: OpProto):
    """
    List the variable names used as gradient arguments for the operator.

    Naming rules:
    - Any optional argument maps to ``<name>_value``.
    - Required arguments map to ``cast_<name>_tensor`` for tensors,
      ``cast_<name>_tensor_list`` for tensor lists, and ``cast_<name>`` for anything else.

    Args:
        op_proto (OpProto): The operator prototype containing the argument information.

    Returns:
        list: Gradient argument variable names, in declaration order.
    """
    grad_args = []
    for arg in op_proto.op_args:
        if pyboost_utils.is_tensor(arg):
            required_name = "cast_" + arg.arg_name + "_tensor"
        elif pyboost_utils.is_tensor_list(arg):
            required_name = "cast_" + arg.arg_name + "_tensor_list"
        else:
            required_name = "cast_" + arg.arg_name
        # Optional arguments of every kind go through their unwrapped "_value" variable.
        grad_args.append(arg.arg_name + "_value" if is_optional_param(arg) else required_name)
    return grad_args
322
+
323
def _get_cast_to_value_str(self, op_proto: OpProto):
    """
    List the ``cast_``-prefixed variable names for each operator argument.

    The part after ``cast_`` follows these rules:
    - Tensors use ``<name>_optional`` when optional, else ``<name>_tensor``.
    - Tensor lists use ``<name>_optional`` when optional, else ``<name>_tensor_list``.
    - Every other argument uses ``<name>``.

    Args:
        op_proto (OpProto): The operator prototype containing the argument information.

    Returns:
        list: Cast variable names, in declaration order.
    """
    names = []
    for arg in op_proto.op_args:
        if pyboost_utils.is_tensor(arg):
            stem = arg.arg_name + ('_optional' if is_optional_param(arg) else "_tensor")
        elif pyboost_utils.is_tensor_list(arg):
            # Optional TensorList arguments also use the "_optional" stub name.
            stem = arg.arg_name + ('_optional' if is_optional_param(arg) else "_tensor_list")
        else:
            stem = arg.arg_name
        names.append('cast_' + stem)
    return names
351
+
352
+ def _get_first_str(self, is_view_or_inplace: bool, grad_args: list):
353
+ """
354
+ Generates the view base str of arguments for the operator.
355
+
356
+ This method constructs a list of argument names that need to be cast to their corresponding types.
357
+
358
+ Args:
359
+ is_view_or_inplace (bool): Whether the op is view op or inplace op.
360
+ grad_args (list): grad args
361
+
362
+ Returns:
363
+ str: Formatted view or inplace first argument names.
364
+ """
365
+ arg_str = ''
366
+ for i, grad_arg in enumerate(grad_args):
367
+ if is_view_or_inplace and i == 0:
368
+ arg_str = grad_arg
369
+ break
370
+ return arg_str
@@ -0,0 +1,68 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ This module defines the `PyboostFunctionsHeaderGenerator` class, which is responsible for generating
17
+ the header file (`pyboost_functions.h`) for Pyboost function declarations.
18
+
19
+ The class uses templates and operation prototypes to create function declarations based on the
20
+ operation's primitive and arguments. The generated file is saved to the specified path.
21
+ """
22
+
23
+ import os
24
+
25
+ import template
26
+
27
+ from template import Template
28
+ import gen_constants as K
29
+ from gen_utils import save_file
30
+ from op_template_parser import OpTemplateParser
31
+ from base_generator import BaseGenerator
32
+
33
+
34
class PyboostFunctionsHeaderGenerator(BaseGenerator):
    """
    Produce ``pyboost_functions.h``.

    The header holds one ``ME_EXPORT`` ``<name>_Base`` declaration for every operator
    prototype whose dispatch is enabled.
    """

    def __init__(self):
        """Store the whole-file header template and the per-function declaration template."""
        self.PYBOOST_FUNCTION_HEADER_TEMPLATE = template.PYBOOST_FUNCTION_HEADER_TEMPLATE

        self.pyboost_func_template = Template(
            'py::object ME_EXPORT ${func_name}_Base(const PrimitivePtr &prim, const py::list &args);'
        )

    def generate(self, work_path, op_protos):
        """
        Render and write ``pyboost_functions.h``.

        Prototypes whose ``op_dispatch`` is ``None`` or not enabled are skipped. For the
        rest, the declaration name comes from ``OpTemplateParser.get_pyboost_func_name``.

        Args:
            work_path (str): Root directory; the file goes under ``K.PIPELINE_PYBOOST_FUNC_GEN_PATH``.
            op_protos (list): Operator prototypes to declare.

        Returns:
            None: The header is written via ``save_file``.
        """
        declarations = [
            self.pyboost_func_template.replace(
                func_name=OpTemplateParser(proto).get_pyboost_func_name())
            for proto in op_protos
            if proto.op_dispatch is not None and proto.op_dispatch.enable
        ]
        header_text = self.PYBOOST_FUNCTION_HEADER_TEMPLATE.replace(prim_func_list=declarations)
        output_dir = os.path.join(work_path, K.PIPELINE_PYBOOST_FUNC_GEN_PATH)
        save_file(output_dir, "pyboost_functions.h", header_text)
@@ -0,0 +1,148 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ This module defines the PyboostFunctionsPyGenerator class for generating Python bindings for PyBoost functions.
17
+
18
+ The PyboostFunctionsPyGenerator class processes operator prototypes and generates Python functions
19
+ that correspond to the PyBoost operations defined in the operator prototypes. It handles the necessary
20
+ argument processing and includes appropriate documentation descriptions.
21
+ """
22
+
23
+ import os
24
+
25
+ import template
26
+ import gen_constants as K
27
+ from gen_utils import save_file
28
+ from op_proto import OpProto
29
+ from base_generator import BaseGenerator
30
+
31
+
32
class PyboostFunctionsPyGenerator(BaseGenerator):
    """
    Produce ``gen_extend_func.py``, the Python wrappers for pyboost ``_ext`` operators.

    An operator is emitted only when all of the following hold:
    - its dispatch is enabled;
    - its function is not disabled;
    - its function name or op name ends in ``_ext``.

    Each emitted wrapper is rendered from ``PYBOOST_PY_FUNC_TEMPLATE`` using the
    description found in the documentation data.
    """

    def __init__(self):
        """Load the file header and per-function templates from the ``template`` module."""
        self.IMPORT_PYBOOST_FUNC_HEADER = template.IMPORT_PYBOOST_FUNC_HEADER
        self.PYBOOST_PY_FUNC_TEMPLATE = template.PYBOOST_PY_FUNC_TEMPLATE

    @staticmethod
    def _should_generate(op_proto):
        """Return True when ``op_proto`` qualifies for a generated Python wrapper."""
        dispatch = op_proto.op_dispatch
        # Only operators in the pyboost scenario are considered.
        if dispatch is None or not dispatch.enable:
            return False
        if op_proto.op_function.disable:
            return False
        return op_proto.op_function.name.endswith("_ext") or op_proto.op_name.endswith("_ext")

    def generate(self, work_path, op_protos, doc_data):
        """
        Render and write ``gen_extend_func.py``.

        Args:
            work_path (str): Root directory; the file goes under ``K.PY_AUTO_GEN_PATH``.
            op_protos (list): Operator prototypes to consider.
            doc_data (dict): Operator name -> documentation mapping; each entry's
                ``"description"`` is used.

        Returns:
            None
        """
        descriptions = self._get_op_description_dict(doc_data)
        rendered = []
        for op_proto in op_protos:
            if not self._should_generate(op_proto):
                continue
            description = descriptions.get(op_proto.op_name)
            func_args, input_args = self._process_args(op_proto.op_args)
            func_name, func_impl_name = self._get_func_impl_name(op_proto)
            rendered.append(self.PYBOOST_PY_FUNC_TEMPLATE.replace(func_name=func_name,
                                                                  description=description,
                                                                  func_args=func_args,
                                                                  input_args=input_args,
                                                                  func_impl_name=func_impl_name))
        file_text = K.PY_LICENSE + self.IMPORT_PYBOOST_FUNC_HEADER + ''.join(rendered)
        save_file(os.path.join(work_path, K.PY_AUTO_GEN_PATH), "gen_extend_func.py", file_text)

    def _get_op_description_dict(self, doc_yaml_data):
        """
        Map each operator name to the ``"description"`` entry of its documentation.

        Args:
            doc_yaml_data (dict): Operator name -> documentation mapping.

        Returns:
            dict: Operator name -> description (``None`` when the entry has no description).
        """
        return {name: doc.get("description") for name, doc in doc_yaml_data.items()}

    def _process_args(self, op_args):
        """
        Build the wrapper's parameter list and the names forwarded to the implementation.

        Args:
            op_args (list): Operator arguments.

        Returns:
            tuple: ``(func_args, input_args)``.
            - ``func_args`` holds ``name`` or ``name=<default>``.
            - ``input_args`` holds ``converted_<name>`` for arguments whose handler is
              neither empty nor ``'dtype_to_type_id'``, and plain ``name`` otherwise.
        """
        func_args = []
        input_args = []
        for arg in op_args:
            name = arg.arg_name
            needs_conversion = arg.arg_handler not in ('', 'dtype_to_type_id')
            input_args.append('converted_' + name if needs_conversion else name)
            default = arg.default
            func_args.append(name if default is None else name + '=' + str(default))
        return func_args, input_args

    def _get_func_impl_name(self, op_proto: OpProto):
        """
        Derive the public wrapper name and its implementation name.

        The public name is ``op_function.name``, falling back to ``op_name`` when that is
        empty. A trailing ``_ext`` is stripped from it. The implementation name is the same
        string with one trailing ``_`` removed, if present.

        Args:
            op_proto (OpProto): The operator prototype.

        Returns:
            tuple: ``(func_name, func_impl_name)``.
        """
        func_name = op_proto.op_function.name
        if func_name == '':
            func_name = op_proto.op_name
        if func_name.endswith("_ext"):
            func_name = func_name[:-len("_ext")]
        func_impl_name = func_name[:-1] if func_name.endswith("_") else func_name
        return func_name, func_impl_name