mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.5.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (366) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +8 -3
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +0 -5
  7. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  8. mindspore/_extends/parse/compile_config.py +64 -0
  9. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  10. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
  11. mindspore/_extends/parse/parser.py +23 -5
  12. mindspore/_extends/parse/standard_method.py +123 -27
  13. mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
  14. mindspore/amp.py +7 -1
  15. mindspore/avcodec-59.dll +0 -0
  16. mindspore/avdevice-59.dll +0 -0
  17. mindspore/avfilter-8.dll +0 -0
  18. mindspore/avformat-59.dll +0 -0
  19. mindspore/avutil-57.dll +0 -0
  20. mindspore/boost/boost_cell_wrapper.py +136 -41
  21. mindspore/common/__init__.py +3 -1
  22. mindspore/common/_register_for_tensor.py +0 -1
  23. mindspore/common/_stub_tensor.py +25 -4
  24. mindspore/common/_tensor_cpp_method.py +17 -0
  25. mindspore/common/_tensor_docs.py +6132 -0
  26. mindspore/common/api.py +98 -21
  27. mindspore/common/dtype.py +34 -34
  28. mindspore/common/dump.py +2 -1
  29. mindspore/common/file_system.py +8 -3
  30. mindspore/common/generator.py +2 -0
  31. mindspore/common/hook_handle.py +3 -1
  32. mindspore/common/initializer.py +3 -4
  33. mindspore/common/lazy_inline.py +8 -2
  34. mindspore/common/mindir_util.py +10 -2
  35. mindspore/common/parameter.py +31 -15
  36. mindspore/common/tensor.py +713 -1337
  37. mindspore/communication/__init__.py +1 -1
  38. mindspore/communication/_comm_helper.py +5 -0
  39. mindspore/communication/comm_func.py +215 -173
  40. mindspore/communication/management.py +23 -20
  41. mindspore/context.py +285 -191
  42. mindspore/dataset/__init__.py +23 -19
  43. mindspore/dataset/callback/ds_callback.py +2 -1
  44. mindspore/dataset/core/config.py +84 -3
  45. mindspore/dataset/engine/cache_admin.py +3 -3
  46. mindspore/dataset/engine/cache_client.py +5 -4
  47. mindspore/dataset/engine/datasets.py +192 -149
  48. mindspore/dataset/engine/datasets_audio.py +14 -0
  49. mindspore/dataset/engine/datasets_standard_format.py +11 -11
  50. mindspore/dataset/engine/datasets_text.py +38 -1
  51. mindspore/dataset/engine/datasets_user_defined.py +100 -66
  52. mindspore/dataset/engine/datasets_vision.py +81 -8
  53. mindspore/dataset/engine/iterators.py +281 -63
  54. mindspore/dataset/engine/obs/util.py +8 -0
  55. mindspore/dataset/engine/queue.py +40 -0
  56. mindspore/dataset/engine/samplers.py +26 -2
  57. mindspore/dataset/engine/serializer_deserializer.py +1 -1
  58. mindspore/dataset/engine/validators.py +43 -11
  59. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  60. mindspore/dataset/transforms/transforms.py +29 -12
  61. mindspore/dataset/vision/validators.py +1 -2
  62. mindspore/device_context/__init__.py +21 -0
  63. mindspore/device_context/ascend/__init__.py +25 -0
  64. mindspore/device_context/ascend/device.py +72 -0
  65. mindspore/device_context/ascend/op_debug.py +94 -0
  66. mindspore/device_context/ascend/op_precision.py +193 -0
  67. mindspore/device_context/ascend/op_tuning.py +127 -0
  68. mindspore/device_context/cpu/__init__.py +25 -0
  69. mindspore/device_context/cpu/device.py +62 -0
  70. mindspore/device_context/cpu/op_tuning.py +43 -0
  71. mindspore/device_context/gpu/__init__.py +21 -0
  72. mindspore/device_context/gpu/device.py +70 -0
  73. mindspore/device_context/gpu/op_precision.py +67 -0
  74. mindspore/device_context/gpu/op_tuning.py +175 -0
  75. mindspore/device_manager.py +134 -0
  76. mindspore/dnnl.dll +0 -0
  77. mindspore/experimental/llm_boost/__init__.py +1 -0
  78. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  79. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  80. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  81. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  82. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  83. mindspore/experimental/llm_boost/register.py +1 -0
  84. mindspore/experimental/optim/adadelta.py +26 -22
  85. mindspore/experimental/optim/adam.py +3 -0
  86. mindspore/experimental/optim/lr_scheduler.py +33 -24
  87. mindspore/experimental/optim/radam.py +33 -30
  88. mindspore/hal/device.py +28 -0
  89. mindspore/hal/event.py +17 -0
  90. mindspore/hal/memory.py +94 -3
  91. mindspore/hal/stream.py +91 -6
  92. mindspore/include/api/context.h +0 -1
  93. mindspore/jpeg62.dll +0 -0
  94. mindspore/log.py +12 -0
  95. mindspore/mindrecord/__init__.py +1 -1
  96. mindspore/mindrecord/config.py +17 -316
  97. mindspore/mindrecord/filereader.py +1 -9
  98. mindspore/mindrecord/filewriter.py +5 -15
  99. mindspore/mindrecord/mindpage.py +1 -9
  100. mindspore/mindspore_backend.dll +0 -0
  101. mindspore/mindspore_common.dll +0 -0
  102. mindspore/mindspore_core.dll +0 -0
  103. mindspore/mindspore_glog.dll +0 -0
  104. mindspore/mindspore_ops.dll +0 -0
  105. mindspore/mint/__init__.py +824 -218
  106. mindspore/mint/distributed/__init__.py +66 -4
  107. mindspore/mint/distributed/distributed.py +2594 -44
  108. mindspore/mint/linalg/__init__.py +6 -0
  109. mindspore/mint/nn/__init__.py +473 -14
  110. mindspore/mint/nn/functional.py +486 -11
  111. mindspore/mint/nn/layer/__init__.py +17 -4
  112. mindspore/mint/nn/layer/_functions.py +330 -0
  113. mindspore/mint/nn/layer/activation.py +169 -1
  114. mindspore/mint/nn/layer/basic.py +123 -0
  115. mindspore/mint/nn/layer/conv.py +727 -0
  116. mindspore/mint/nn/layer/normalization.py +215 -19
  117. mindspore/mint/nn/layer/padding.py +797 -0
  118. mindspore/mint/nn/layer/pooling.py +170 -0
  119. mindspore/mint/optim/__init__.py +2 -1
  120. mindspore/mint/optim/adam.py +223 -0
  121. mindspore/mint/optim/adamw.py +26 -19
  122. mindspore/mint/special/__init__.py +2 -1
  123. mindspore/multiprocessing/__init__.py +5 -0
  124. mindspore/nn/cell.py +126 -19
  125. mindspore/nn/dynamic_lr.py +2 -1
  126. mindspore/nn/layer/activation.py +6 -6
  127. mindspore/nn/layer/basic.py +35 -25
  128. mindspore/nn/layer/channel_shuffle.py +3 -3
  129. mindspore/nn/layer/embedding.py +3 -3
  130. mindspore/nn/layer/normalization.py +8 -7
  131. mindspore/nn/layer/padding.py +4 -3
  132. mindspore/nn/layer/pooling.py +47 -13
  133. mindspore/nn/layer/rnn_cells.py +1 -1
  134. mindspore/nn/layer/rnns.py +2 -1
  135. mindspore/nn/layer/timedistributed.py +5 -5
  136. mindspore/nn/layer/transformer.py +48 -26
  137. mindspore/nn/learning_rate_schedule.py +5 -3
  138. mindspore/nn/loss/loss.py +31 -36
  139. mindspore/nn/optim/ada_grad.py +1 -0
  140. mindspore/nn/optim/adadelta.py +2 -2
  141. mindspore/nn/optim/adam.py +1 -1
  142. mindspore/nn/optim/lars.py +1 -4
  143. mindspore/nn/optim/optimizer.py +1 -1
  144. mindspore/nn/optim/rprop.py +2 -2
  145. mindspore/nn/optim/thor.py +2 -1
  146. mindspore/nn/utils/init.py +13 -11
  147. mindspore/nn/wrap/cell_wrapper.py +4 -6
  148. mindspore/nn/wrap/loss_scale.py +3 -4
  149. mindspore/numpy/array_creations.py +60 -62
  150. mindspore/numpy/array_ops.py +148 -143
  151. mindspore/numpy/logic_ops.py +41 -42
  152. mindspore/numpy/math_ops.py +361 -359
  153. mindspore/numpy/utils.py +16 -16
  154. mindspore/numpy/utils_const.py +4 -4
  155. mindspore/opencv_core452.dll +0 -0
  156. mindspore/opencv_imgcodecs452.dll +0 -0
  157. mindspore/opencv_imgproc452.dll +0 -0
  158. mindspore/ops/__init__.py +2 -1
  159. mindspore/ops/_grad_experimental/grad_comm_ops.py +94 -13
  160. mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
  161. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  162. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  163. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  164. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  165. mindspore/ops/_vmap/vmap_array_ops.py +20 -19
  166. mindspore/ops/_vmap/vmap_base.py +0 -2
  167. mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
  168. mindspore/ops/_vmap/vmap_math_ops.py +11 -9
  169. mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
  170. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
  171. mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
  172. mindspore/ops/auto_generate/gen_extend_func.py +554 -60
  173. mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
  174. mindspore/ops/auto_generate/gen_ops_prim.py +8024 -3409
  175. mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
  176. mindspore/ops/composite/base.py +1 -1
  177. mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
  178. mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
  179. mindspore/ops/function/__init__.py +12 -0
  180. mindspore/ops/function/array_func.py +561 -159
  181. mindspore/ops/function/clip_func.py +64 -0
  182. mindspore/ops/function/debug_func.py +28 -20
  183. mindspore/ops/function/image_func.py +1 -1
  184. mindspore/ops/function/linalg_func.py +5 -4
  185. mindspore/ops/function/math_func.py +1659 -290
  186. mindspore/ops/function/nn_func.py +988 -317
  187. mindspore/ops/function/parameter_func.py +3 -56
  188. mindspore/ops/function/random_func.py +243 -33
  189. mindspore/ops/function/sparse_unary_func.py +1 -1
  190. mindspore/ops/functional.py +18 -5
  191. mindspore/ops/functional_overload.py +897 -0
  192. mindspore/ops/operations/__init__.py +3 -2
  193. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  194. mindspore/ops/operations/_grad_ops.py +2 -34
  195. mindspore/ops/operations/_infer_ops.py +2 -1
  196. mindspore/ops/operations/_inner_ops.py +38 -8
  197. mindspore/ops/operations/array_ops.py +45 -303
  198. mindspore/ops/operations/comm_ops.py +19 -16
  199. mindspore/ops/operations/custom_ops.py +11 -55
  200. mindspore/ops/operations/debug_ops.py +42 -47
  201. mindspore/ops/operations/inner_ops.py +6 -4
  202. mindspore/ops/operations/linalg_ops.py +3 -2
  203. mindspore/ops/operations/manually_defined/ops_def.py +185 -104
  204. mindspore/ops/operations/math_ops.py +11 -216
  205. mindspore/ops/operations/nn_ops.py +146 -308
  206. mindspore/ops/primitive.py +23 -21
  207. mindspore/ops/tensor_method.py +1669 -0
  208. mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
  209. mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
  210. mindspore/ops_generate/arg_handler.py +0 -61
  211. mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
  212. mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
  213. mindspore/ops_generate/base_generator.py +11 -0
  214. mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
  215. mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
  216. mindspore/ops_generate/functional_overload_py_generator.py +110 -0
  217. mindspore/ops_generate/functions_cc_generator.py +233 -0
  218. mindspore/ops_generate/gen_aclnn_implement.py +110 -114
  219. mindspore/ops_generate/gen_constants.py +157 -3
  220. mindspore/ops_generate/gen_ops.py +245 -990
  221. mindspore/ops_generate/gen_pyboost_func.py +97 -998
  222. mindspore/ops_generate/gen_utils.py +119 -33
  223. mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
  224. mindspore/ops_generate/op_api_proto.py +206 -0
  225. mindspore/ops_generate/op_def_py_generator.py +131 -0
  226. mindspore/ops_generate/op_prim_py_generator.py +480 -0
  227. mindspore/ops_generate/op_proto.py +373 -108
  228. mindspore/ops_generate/op_template_parser.py +436 -0
  229. mindspore/ops_generate/ops_def_cc_generator.py +288 -0
  230. mindspore/ops_generate/ops_def_h_generator.py +74 -0
  231. mindspore/ops_generate/ops_name_h_generator.py +68 -0
  232. mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
  233. mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
  234. mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
  235. mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
  236. mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
  237. mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
  238. mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
  239. mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
  240. mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
  241. mindspore/ops_generate/pyboost_utils.py +92 -33
  242. mindspore/ops_generate/template.py +294 -44
  243. mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
  244. mindspore/parallel/__init__.py +3 -3
  245. mindspore/parallel/_auto_parallel_context.py +24 -33
  246. mindspore/parallel/_parallel_serialization.py +13 -2
  247. mindspore/parallel/_utils.py +4 -1
  248. mindspore/parallel/algo_parameter_config.py +1 -1
  249. mindspore/parallel/checkpoint_transform.py +44 -0
  250. mindspore/parallel/cluster/process_entity/_api.py +131 -37
  251. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  252. mindspore/parallel/cluster/run.py +20 -3
  253. mindspore/parallel/parameter_broadcast.py +1 -1
  254. mindspore/parallel/shard.py +3 -0
  255. mindspore/parallel/transform_safetensors.py +119 -253
  256. mindspore/profiler/__init__.py +17 -4
  257. mindspore/profiler/analysis/__init__.py +0 -0
  258. mindspore/profiler/analysis/parser/__init__.py +0 -0
  259. mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
  260. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  261. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  262. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  263. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  264. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  265. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
  266. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  267. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
  268. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  269. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  270. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  271. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  272. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  273. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  274. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  275. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  276. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  277. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  278. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
  279. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  280. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  281. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  282. mindspore/profiler/analysis/task_manager.py +131 -0
  283. mindspore/profiler/analysis/time_converter.py +84 -0
  284. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  285. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
  286. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  287. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
  288. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
  289. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
  290. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
  291. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  292. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  293. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
  294. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  295. mindspore/profiler/analysis/work_flow.py +73 -0
  296. mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
  297. mindspore/profiler/common/command_executor.py +90 -0
  298. mindspore/profiler/common/constant.py +174 -3
  299. mindspore/profiler/common/file_manager.py +208 -0
  300. mindspore/profiler/common/log.py +130 -0
  301. mindspore/profiler/common/msprof_cmd_tool.py +202 -0
  302. mindspore/profiler/common/path_manager.py +371 -0
  303. mindspore/profiler/common/process_bar.py +168 -0
  304. mindspore/profiler/common/process_pool.py +9 -3
  305. mindspore/profiler/common/profiler_context.py +476 -0
  306. mindspore/profiler/common/profiler_info.py +304 -0
  307. mindspore/profiler/common/profiler_output_path.py +284 -0
  308. mindspore/profiler/common/profiler_parameters.py +210 -0
  309. mindspore/profiler/common/profiler_path_manager.py +120 -0
  310. mindspore/profiler/common/record_function.py +76 -0
  311. mindspore/profiler/common/tlv_decoder.py +76 -0
  312. mindspore/profiler/common/util.py +75 -2
  313. mindspore/profiler/dynamic_profiler.py +270 -37
  314. mindspore/profiler/envprofiler.py +138 -0
  315. mindspore/profiler/mstx.py +199 -0
  316. mindspore/profiler/platform/__init__.py +21 -0
  317. mindspore/profiler/platform/base_profiler.py +40 -0
  318. mindspore/profiler/platform/cpu_profiler.py +124 -0
  319. mindspore/profiler/platform/gpu_profiler.py +74 -0
  320. mindspore/profiler/platform/npu_profiler.py +309 -0
  321. mindspore/profiler/profiler.py +580 -93
  322. mindspore/profiler/profiler_action_controller.py +187 -0
  323. mindspore/profiler/profiler_interface.py +114 -0
  324. mindspore/profiler/schedule.py +208 -0
  325. mindspore/rewrite/api/symbol_tree.py +1 -2
  326. mindspore/run_check/_check_version.py +2 -6
  327. mindspore/runtime/__init__.py +37 -0
  328. mindspore/runtime/device.py +27 -0
  329. mindspore/runtime/event.py +209 -0
  330. mindspore/runtime/executor.py +148 -0
  331. mindspore/runtime/memory.py +392 -0
  332. mindspore/runtime/stream.py +460 -0
  333. mindspore/runtime/thread_bind_core.py +401 -0
  334. mindspore/swresample-4.dll +0 -0
  335. mindspore/swscale-6.dll +0 -0
  336. mindspore/tinyxml2.dll +0 -0
  337. mindspore/train/__init__.py +2 -2
  338. mindspore/train/_utils.py +53 -18
  339. mindspore/train/amp.py +8 -4
  340. mindspore/train/callback/_checkpoint.py +32 -18
  341. mindspore/train/callback/_early_stop.py +1 -1
  342. mindspore/train/callback/_flops_collector.py +105 -69
  343. mindspore/train/callback/_history.py +1 -1
  344. mindspore/train/callback/_summary_collector.py +44 -6
  345. mindspore/train/callback/_tft_register.py +31 -10
  346. mindspore/train/dataset_helper.py +11 -11
  347. mindspore/train/metrics/precision.py +4 -5
  348. mindspore/train/mind_ir_pb2.py +167 -46
  349. mindspore/train/model.py +13 -15
  350. mindspore/train/serialization.py +462 -76
  351. mindspore/train/summary/summary_record.py +1 -2
  352. mindspore/train/train_thor/model_thor.py +1 -1
  353. mindspore/turbojpeg.dll +0 -0
  354. mindspore/utils/__init__.py +4 -2
  355. mindspore/utils/dryrun.py +138 -0
  356. mindspore/utils/runtime_execution_order_check.py +550 -0
  357. mindspore/version.py +1 -1
  358. {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/METADATA +2 -3
  359. {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/RECORD +362 -238
  360. {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
  361. mindspore/common/_tensor_overload.py +0 -139
  362. mindspore/mindspore_np_dtype.dll +0 -0
  363. mindspore/profiler/envprofiling.py +0 -254
  364. mindspore/profiler/profiling.py +0 -1926
  365. {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
  366. {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
@@ -18,43 +18,13 @@ Generate operator utils function
18
18
  import os
19
19
  import glob
20
20
  import hashlib
21
+ import pathlib
22
+ import re
21
23
  import stat
24
+ import logging
22
25
  import yaml
23
26
 
24
27
 
25
- py_licence_str = f"""# Copyright 2023 Huawei Technologies Co., Ltd
26
- #
27
- # Licensed under the Apache License, Version 2.0 (the "License");
28
- # you may not use this file except in compliance with the License.
29
- # You may obtain a copy of the License at
30
- #
31
- # http://www.apache.org/licenses/LICENSE-2.0
32
- #
33
- # Unless required by applicable law or agreed to in writing, software
34
- # distributed under the License is distributed on an "AS IS" BASIS,
35
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36
- # See the License for the specific language governing permissions and
37
- # limitations under the License.
38
- # ============================================================================
39
- """
40
-
41
- cc_license_str = f"""/**
42
- * Copyright 2023 Huawei Technologies Co., Ltd
43
- *
44
- * Licensed under the Apache License, Version 2.0 (the "License");
45
- * you may not use this file except in compliance with the License.
46
- * You may obtain a copy of the License at
47
- *
48
- * http://www.apache.org/licenses/LICENSE-2.0
49
- *
50
- * Unless required by applicable law or agreed to in writing, software
51
- * distributed under the License is distributed on an "AS IS" BASIS,
52
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53
- * See the License for the specific language governing permissions and
54
- * limitations under the License.
55
- */"""
56
-
57
-
58
28
  def convert_dtype_str(dtype_str):
59
29
  """
60
30
  Convert dtype str to expression in ops file
@@ -207,3 +177,119 @@ def write_file(path, data):
207
177
  fd = os.open(path, flags, mode)
208
178
  with os.fdopen(fd, "w") as f:
209
179
  f.write(data)
180
+
181
+
182
def save_file(save_path, file_name, content):
    """
    Write generated content into ``save_path`` through a temporary file.

    The target directory is created if it is missing. The content is first written to
    ``tmp_<file_name>`` in that directory, and then both paths are handed to
    ``check_change_and_replace_file``.

    Args:
        save_path (str): Directory that receives the file.
        file_name (str): Name of the destination file inside ``save_path``.
        content (str): Text to write.
    """
    target_dir = pathlib.Path(save_path)
    target_dir.mkdir(parents=True, exist_ok=True)

    final_path = os.path.join(save_path, file_name)
    staging_path = os.path.join(save_path, f"tmp_{file_name}")

    write_file(staging_path, content)
    # NOTE(review): presumably the helper only replaces final_path when the staged content
    # differs -- it is defined elsewhere, confirm there.
    check_change_and_replace_file(final_path, staging_path)
188
+
189
+
190
def normalize_func_description_format(description):
    """
    Indent the continuation lines of a multi-line description.

    Empty/None input and single-line descriptions are returned unchanged. Otherwise every
    non-empty line after the first is prefixed with the indent string, and a single trailing
    empty line (if present) is dropped.

    Args:
        description (str): Description text, possibly spanning several lines.

    Returns:
        str: The reformatted description.
    """
    if not description:
        return description

    first_line, *following = description.split("\n")
    if not following:
        return description

    # Blank lines stay blank so that no trailing whitespace is introduced.
    shifted = [" " + text if text else text for text in following]

    # Only the very last line is examined; earlier blank lines are kept.
    if shifted[-1] == "":
        shifted.pop()

    return "\n".join([first_line] + shifted)
209
+
210
+
211
def get_op_description(operator_name, doc_dict):
    """
    Build the raw-docstring snippet for an operator API.

    The description is looked up as ``doc_dict[operator_name]["description"]``. If ``doc_dict``
    is None, or either lookup yields None, the miss is logged at INFO level and a placeholder
    docstring with an empty body is returned.

    Args:
        operator_name (str): Key of the operator inside ``doc_dict``.
        doc_dict (dict): Mapping from operator name to a dict holding a ``"description"`` entry.

    Returns:
        str: Text of an ``r\"\"\"...\"\"\"`` docstring block.
    """
    op_doc = doc_dict.get(operator_name) if doc_dict is not None else None
    text = op_doc.get("description") if op_doc is not None else None

    if text is None:
        logging.info("Description is None, op_name: %s", operator_name)
        return " r\"\"\"\n \n \"\"\"\n"

    return f" r\"\"\"\n {normalize_func_description_format(text)}\n \"\"\"\n"
236
+
237
+
238
def get_same_dtype_groups(args_signature, args_name):
    """
    Assign a dtype-group index to every argument.

    ``args_signature.dtype_group`` is scanned for parenthesised, comma-separated name lists.
    Arguments are visited in the order given by ``args_name``: an argument found in a list
    shares one index with every name of the first list containing it, while an argument found
    in no list gets an index of its own. Arguments that already have an index are skipped.

    Args:
        args_signature (object): Object exposing a ``dtype_group`` string attribute; may be falsy.
        args_name (list[str]): Argument names, in visiting order.

    Returns:
        tuple: ``(mapping, count)`` - the name-to-index dict and the number of indices handed out.
        ``({}, 0)`` is returned when there is no signature or no dtype group.
    """
    if not args_signature or not args_signature.dtype_group:
        return {}, 0

    groups = [
        chunk.replace(' ', '').split(",")
        for chunk in re.findall(r'\((.*?)\)', args_signature.dtype_group)
    ]

    group_index = {}
    next_index = 0
    for name in args_name:
        if name in group_index:
            continue
        # First declared group containing the name wins; otherwise the name stands alone.
        members = next((group for group in groups if name in group), [name])
        for member in members:
            group_index[member] = next_index
        next_index += 1
    return group_index, next_index
270
+
271
+
272
def init_args_signature_rw(args_signature):
    """
    Split the read/ref/write declarations of a signature into name lists.

    Args:
        args_signature (object): Object exposing ``rw_write``, ``rw_read`` and ``rw_ref``
            attributes holding comma-separated names; may be falsy.

    Returns:
        tuple: ``(read_list, ref_list, write_list)``. A list is empty when the signature is
        falsy or the matching attribute is empty.
    """

    def _split_names(raw):
        # Spaces are removed before splitting so that "a, b" yields ["a", "b"].
        return raw.replace(' ', '').split(",") if raw else []

    if not args_signature:
        return [], [], []

    writes = _split_names(args_signature.rw_write)
    reads = _split_names(args_signature.rw_read)
    refs = _split_names(args_signature.rw_ref)
    return reads, refs, writes
@@ -0,0 +1,155 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ Generates C++ header and source files for lite operations based on YAML configurations.
17
+ """
18
+
19
+ import os
20
+
21
+ import gen_constants as K
22
+ import gen_utils
23
+ import pyboost_utils
24
+
25
+ # refactored
26
+ import template
27
+
28
+ from base_generator import BaseGenerator
29
+
30
+
31
class LiteOpsHGenerator(BaseGenerator):
    """
    Generator for the lite operations C++ header, saved as ``gen_lite_ops.h``.
    """

    def __init__(self):
        """
        Prepare the templates used to render the header file.
        """
        self.lite_ops_h_template = template.Template(K.LITE_OPS_H)
        self.lite_ops_class_template = template.op_cc_template
        # Declares one setter/getter pair for a primitive-init argument.
        accessor_decl = ("\n"
                         " void set_${arg_name}(const ${dtype} &${arg_name});\n"
                         " ${dtype} get_${arg_name}() const;")
        self.arg_prim_init_template = template.Template(accessor_decl)

    def generate(self, work_path, op_protos):
        """
        Render the header for all given operators and save it under ``work_path``.

        Args:
            work_path (str): Base directory; the file goes into ``K.MS_OP_DEF_AUTO_GENERATE_PATH``
                below it.
            op_protos (list): Operator prototypes; each one provides ``op_name``, ``op_class.name``
                and ``op_args``.

        Returns:
            None
        """
        class_snippets = []
        for proto in op_protos:
            op_name = pyboost_utils.get_op_name(proto.op_name, proto.op_class.name)
            # Only arguments flagged as primitive-init get accessor declarations.
            accessors = "".join(
                self.arg_prim_init_template.replace(arg_name=arg.arg_name,
                                                    dtype=trans_dtype_for_lite(arg.arg_dtype))
                for arg in proto.op_args
                if arg.is_prim_init
            )
            class_snippets.append(
                self.lite_ops_class_template.replace(op_name=op_name, arg_prim_init_list=accessors))

        header_body = self.lite_ops_h_template.replace(auto_gen_path=K.OP_DEF_AUTO_GENERATE_PATH,
                                                       ops_namespace_body=class_snippets)

        gen_utils.save_file(os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH),
                            "gen_lite_ops.h",
                            template.CC_LICENSE_STR + header_body)
81
+
82
+
83
class LiteOpsCcGenerator(BaseGenerator):
    """
    Generator for the lite operations C++ source, saved as ``gen_lite_ops.cc``.
    """

    def __init__(self):
        """
        Prepare the templates used to render the source file.
        """
        self.lite_ops_cc_template = template.Template(K.LITE_OPS_CC)
        self.op_template = template.op_template
        self.register_primitive_c_template = template.Template("REGISTER_PRIMITIVE_C(kName${op_name}, ${op_name});\n"
                                                               "MIND_API_OPERATOR_IMPL(${op_name}, BaseOperator);\n\n")

    def generate(self, work_path, op_protos):
        """
        Render the source for all given operators and save it under ``work_path``.

        For each operator, the per-argument code of every primitive-init argument is emitted
        first, followed by the operator's registration lines.

        Args:
            work_path (str): Base directory; the file goes into ``K.MS_OP_DEF_AUTO_GENERATE_PATH``
                below it.
            op_protos (list): Operator prototypes; each one provides ``op_name``, ``op_class.name``
                and ``op_args``.

        Returns:
            None
        """
        lite_ops_cc_gen_list = []
        for op_proto in op_protos:
            op_name = pyboost_utils.get_op_name(op_proto.op_name, op_proto.op_class.name)
            # Only arguments flagged as primitive-init contribute code.
            arg_prim_init_str = "".join(
                self.op_template.replace(op_name=op_name,
                                         arg_name=op_arg.arg_name,
                                         dtype=trans_dtype_for_lite(op_arg.arg_dtype))
                for op_arg in op_proto.op_args
                if op_arg.is_prim_init
            )

            # Render the registration snippet exactly once; the previous extra call threw its
            # result away.
            register_str = self.register_primitive_c_template.replace(op_name=op_name)
            lite_ops_cc_gen_list.append(arg_prim_init_str + register_str)

        lite_ops_cc = self.lite_ops_cc_template.replace(auto_gen_path=K.OP_DEF_AUTO_GENERATE_PATH,
                                                        ops_namespace_body=lite_ops_cc_gen_list)

        res_str = template.CC_LICENSE_STR + lite_ops_cc
        save_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
        file_name = "gen_lite_ops.cc"
        gen_utils.save_file(save_path, file_name, res_str)
131
+
132
+
133
def trans_dtype_for_lite(dtype):
    """
    Map a dtype name to the C++ type used by the lite code.

    ``str`` and ``int`` map to ``std::string`` and ``int64_t``; ``tuple[...]`` / ``list[...]``
    of str, int, float or bool map to the matching ``std::vector<...>``. Every other value is
    returned unchanged.

    Args:
        dtype (str): The original data type name.

    Returns:
        str: The translated type, or ``dtype`` itself when no mapping exists.
    """
    element_types = {
        "str": "std::string",
        "int": "int64_t",
        "float": "float",
        "bool": "bool",
    }

    # Scalars: only str and int are renamed.
    lite_types = {name: element_types[name] for name in ("str", "int")}
    # Sequences: tuple and list both become std::vector of the element type.
    for py_elem, cc_elem in element_types.items():
        for container in ("tuple", "list"):
            lite_types[f"{container}[{py_elem}]"] = f"std::vector<{cc_elem}>"

    return lite_types.get(dtype, dtype)
@@ -0,0 +1,206 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """Tensor Func Proto module for defining tensor_py function prototypes and their arguments."""
17
+ import ast
18
+ import os
19
+ from collections import defaultdict
20
+ import gen_constants as K
21
+
22
+
23
class OpApiProto:
    """
    Plain record describing one API definition.

    It stores, unchanged, the function name, the operation prototype it is bound to, the name of the
    Python method, the keyword-only and variable argument lists, and the values given for the
    Ascend, GPU and CPU targets.
    """

    def __init__(self, func_name, op_proto, py_method, kw_only_args, varargs, ascend, gpu, cpu):
        # What the API is and what it is bound to.
        self.func_name, self.op_proto, self.py_method = func_name, op_proto, py_method
        # Argument-kind metadata taken from the API definition.
        self.kw_only_args, self.varargs = kw_only_args, varargs
        # Per-target values.
        self.ascend, self.gpu, self.cpu = ascend, gpu, cpu
45
+
46
+
47
def get_tensor_method_ast_dict():
    """
    Parse 'tensor_method.py' and index its function definitions by name.

    The file is looked up at '../ops/tensor_method.py' relative to the directory of this module.

    Returns:
        dict: Function name -> ``ast.FunctionDef`` node for every ``def`` found anywhere in the file.
        If a name is defined more than once, the node visited last by ``ast.walk`` is kept.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    tensor_method_file = os.path.join(module_dir, '../ops/tensor_method.py')
    with open(tensor_method_file, "r", encoding="utf-8") as source_file:
        source_text = source_file.read()
    module_tree = ast.parse(source_text, filename=tensor_method_file)
    return {
        node.name: node
        for node in ast.walk(module_tree)
        if isinstance(node, ast.FunctionDef)
    }
60
+
61
+
62
def load_api_protos_from_yaml(tensor_func_yaml_data, op_protos, deprecated_op_protos):
    """
    Build OpApiProto records from the parsed API YAML data.

    Args:
        tensor_func_yaml_data (dict): Maps each function name to one definition dict or to a list of
            definition dicts.
        op_protos (list): Operation prototypes, indexed here by their ``op_name``.
        deprecated_op_protos (list): Deprecated operation prototypes; on an ``op_name`` clash they
            replace the entry coming from ``op_protos``.

    Returns:
        tuple: ``(tensor_method_protos, mint_func_protos, alias_api_mapping)``, each a
        ``defaultdict(list)``:
            - tensor_method_protos: function name -> protos whose interface includes 'tensor'.
            - mint_func_protos: function name -> protos whose interface includes 'function'.
            - alias_api_mapping: alias target -> names of the functions declared as its alias.

    Raises:
        TypeError: If no operation prototype matches the definition, if 'py_method' is missing or
            empty, or if it is not defined in tensor_method.py. The key, op-name and kwonlyargs
            checks called from here raise TypeError as well.
        ValueError: If 'interface' is missing or has an unsupported value, or if the varargs check fails.
    """
    # Deprecated prototypes are applied second so they win on a name clash.
    proto_by_op_name = {proto.op_name: proto for proto in op_protos}
    proto_by_op_name.update((proto.op_name, proto) for proto in deprecated_op_protos)

    tensor_method_protos = defaultdict(list)
    mint_func_protos = defaultdict(list)
    alias_api_mapping = defaultdict(list)
    method_ast_by_name = get_tensor_method_ast_dict()
    allowed_interfaces = {'tensor', 'function', 'tensor, function', 'function, tensor'}

    for func_name, yaml_entry in tensor_func_yaml_data.items():
        # A single definition is handled like a one-element list of overloads.
        definitions = [yaml_entry] if isinstance(yaml_entry, dict) else yaml_entry
        for func_data in definitions:
            check_op_api_yaml_keys(func_name, set(func_data.keys()), K.TENSOR_FUNC_KEYS)

            # Alias entries only record the mapping; they carry no prototype of their own.
            if 'alias' in func_data:
                alias_api_mapping[func_data['alias']].append(func_name)
                continue

            op_name = _get_op_name_from_op_yaml(func_name, func_data)
            op_proto = proto_by_op_name.get(op_name)
            if op_proto is None:
                raise TypeError(
                    f"For generating tensor functions, op_proto should not be empty. Func name is {func_name}")

            py_method = func_data.get('py_method', '')
            if py_method == '':
                raise TypeError(
                    f'For generating tensor functions, py method should not be empty. Func name is {func_name}')
            if py_method not in method_ast_by_name:
                raise TypeError(f"{py_method} is not defined in tensor_method.py.")

            # Both fields are comma-separated strings in the YAML; they are validated only when present.
            kw_only_args = func_data.get('kwonlyargs')
            if kw_only_args:
                kw_only_args = [item.strip() for item in kw_only_args.split(',')]
                check_kwonlyargs(func_data, kw_only_args, op_name, op_proto, py_method, method_ast_by_name)
            varargs = func_data.get('varargs')
            if varargs:
                varargs = [item.strip() for item in varargs.split(',')]
                check_varargs(varargs, op_name)

            interface = func_data.get('interface')
            if interface is None:
                raise ValueError(
                    f"For generating tensor or functional interfaces, field interface must exist. "
                    f"Op name is {func_name}")

            # Normalise spacing around commas before comparing with the accepted spellings.
            interface = ', '.join(part.strip() for part in interface.split(','))
            if interface not in allowed_interfaces:
                raise ValueError(
                    f"The value of field 'interface' must be one of 'tensor', 'function', "
                    f"'tensor, function', or 'function, tensor'. File name is {func_name}.yaml"
                )

            proto = OpApiProto(
                func_name=func_name,
                op_proto=op_proto,
                py_method=py_method,
                kw_only_args=kw_only_args,
                varargs=varargs,
                ascend=func_data.get('Ascend', 'aclnn'),
                gpu=func_data.get('GPU', 'aclnn'),
                cpu=func_data.get('CPU', 'aclnn'),
            )

            if 'tensor' in interface:
                tensor_method_protos[func_name].append(proto)
            if 'function' in interface:
                mint_func_protos[func_name].append(proto)

    return tensor_method_protos, mint_func_protos, alias_api_mapping
128
+
129
+
130
def check_kwonlyargs(func_data, kw_only_args, op_name, op_proto, py_method, tensor_method_def_ast_dict):
    """
    Verify the keyword-only arguments declared in the YAML definition.

    Two things are checked:
      1. ``kw_only_args`` must equal, in order, the trailing arguments of the operation prototype.
      2. ``kw_only_args`` must equal, in order, the keyword-only arguments of the Python method.

    Args:
        func_data (dict): The YAML definition; only 'op_yaml' is read, for the error message.
        kw_only_args (list[str]): Keyword-only argument names declared in the YAML.
        op_name (str): Operation name, used in the error message.
        op_proto: Operation prototype; its ``op_args`` items must expose ``arg_name``.
        py_method (str): Key of the method in ``tensor_method_def_ast_dict``.
        tensor_method_def_ast_dict (dict): Method name -> function definition node.

    Raises:
        TypeError: If either check fails, including the case where more keyword-only arguments are
            declared than the operation has arguments.
    """
    op_arg_names = [op_arg.arg_name for op_arg in op_proto.op_args]
    kw_args_start_idx = len(op_arg_names) - len(kw_only_args)
    node = tensor_method_def_ast_dict[py_method]
    tensor_method_kwonlyargs = [arg.arg for arg in node.args.kwonlyargs]

    # A negative start index means more kwonlyargs were declared than the op has arguments.
    # It is rejected explicitly: indexing with it would wrap around and compare against
    # unrelated leading arguments instead of failing.
    expected_kw_args = op_arg_names[max(kw_args_start_idx, 0):]
    if kw_args_start_idx < 0 or expected_kw_args != list(kw_only_args):
        op_yaml = func_data.get('op_yaml')
        raise TypeError(
            f"For generating tensor functions from {op_name}.yaml, "
            f"the order of kwonlyargs should be consistent with the definition in the {op_yaml}. "
            f"Expect kwonlyarg: {expected_kw_args}, current kwonlyarg: {kw_only_args}.")

    if tensor_method_kwonlyargs != kw_only_args:
        raise TypeError(f"The order of kwonlyargs in {py_method} should be consistent with the definition. "
                        f"Expect kwonlyarg: {kw_only_args}, current kwonlyarg: {tensor_method_kwonlyargs}.")
151
+
152
+
153
def check_varargs(varargs, op_name):
    """
    Ensure that exactly one variable argument is declared.

    Args:
        varargs (list): Variable argument names declared for the operation.
        op_name (str): Operation name, used in the error message.

    Raises:
        ValueError: If ``varargs`` does not contain exactly one item.
    """
    varargs_count = len(varargs)
    if varargs_count == 1:
        return
    raise ValueError(
        f'There must be only one variable argument. But got {varargs_count} in {op_name}')
157
+
158
+
159
+ def _get_op_name_from_op_yaml(func_name: str, func_data: dict) -> str:
160
+ """Extracts the operation name from the given YAML function data."""
161
+ op_yaml = func_data.get('op_yaml', '')
162
+ if op_yaml == '':
163
+ raise TypeError(f'For generating tensor functions, op yaml should not be empty, func name is {func_name}')
164
+ if 'deprecated' in op_yaml:
165
+ op_name = op_yaml.replace('/', '_').replace('_method.yaml', '')
166
+ else:
167
+ op_name = op_yaml.replace('_op.yaml', '')
168
+ if op_name == '':
169
+ raise TypeError(f'For generating tensor functions, op name should not be empty, func name is {func_name}')
170
+ return op_name
171
+
172
+
173
def check_op_api_yaml_keys(func_name: str, input_keys: set, compare_keys: set):
    """
    Reject keys that are not part of the accepted key set.

    Args:
        func_name (str): Function name, used in the error message.
        input_keys (set): Keys found in the YAML definition.
        compare_keys (set): Keys that are accepted.

    Raises:
        TypeError: If ``input_keys`` contains a key that is not in ``compare_keys``.
    """
    unexpected_keys = input_keys - compare_keys
    if not unexpected_keys:
        return
    raise TypeError(
        f'The definition of keys in yaml has faults, func name is {func_name}, wrong keys are {unexpected_keys}.')
178
+
179
+
180
def categorize_func_data(func_protos_data):
    """
    Split function prototypes into single-definition and overloaded groups.

    Args:
        func_protos_data (dict): Maps each function API name to the list of prototypes defined for it.

    Returns:
        tuple: ``(all_op_func_data, single_op_func_data, overload_op_func_data)``.
            - all_op_func_data (dict): Name -> list of prototypes, for every entry placed in either
              of the other two dicts.
            - single_op_func_data (dict): ``func_name`` of the prototype -> the prototype itself, for
              APIs with exactly one prototype. The first prototype seen for a ``func_name`` is kept.
            - overload_op_func_data (dict): API name -> list of prototypes, for APIs with more than one.
        APIs with an empty prototype list appear in none of the dicts.
    """
    all_op_func_data, single_op_func_data, overload_op_func_data = {}, {}, {}

    for func_api_name, func_protos in func_protos_data.items():
        proto_count = len(func_protos)
        if proto_count > 1:
            overload_op_func_data[func_api_name] = func_protos
            all_op_func_data[func_api_name] = func_protos
        elif proto_count == 1:
            only_proto = func_protos[0]
            func_name = only_proto.func_name
            if func_name in single_op_func_data:
                continue
            single_op_func_data[func_name] = only_proto
            all_op_func_data[func_name] = func_protos

    return all_op_func_data, single_op_func_data, overload_op_func_data
@@ -0,0 +1,131 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ Generate Python operator definitions.
17
+ """
18
+
19
+ import os
20
+
21
+ import gen_constants as K
22
+ import gen_utils
23
+
24
+ # refactored
25
+ import template
26
+
27
+ from base_generator import BaseGenerator
28
+
29
+
30
class OpDefPyGenerator(BaseGenerator):
    """
    Generates the Python function wrappers for operators.

    For every operation prototype whose function generation is not disabled, one ``def`` is emitted.
    Depending on the prototype it forwards to the operator's ``<op_name>_impl`` function, to a primitive
    obtained through ``_get_cache_prim``, or to the ``<op_name>_op`` primitive instance.
    """

    def __init__(self):
        """
        Initializes the generator with the template for primitive class definitions.
        """
        self.op_prim_class_define_template = template.OP_PRIM_CLASS_DEFINE_TEMPLATE

    def generate(self, work_path, op_protos, doc_dict, file_pre):
        """
        Generates Python code for operator function definitions and saves it to a file.

        Args:
            work_path (str): The base directory where the generated files will be saved.
            op_protos (list): A list of operation prototypes to generate Python code for.
            doc_dict (dict): A dictionary containing documentation strings for the operators.
            file_pre (str): The prefix for the generated Python file name.

        Returns:
            None

        The output is written to ``<file_pre>_ops_def.py`` under ``K.PY_AUTO_GEN_PATH`` inside
        ``work_path``. Prototypes whose ``op_function.disable`` is set are skipped.
        """
        # Collect the pieces and join once instead of growing a string inside the loop.
        gen_py_parts = ["\n"]
        for op_proto in op_protos:
            if op_proto.op_function.disable:
                continue

            class_name = op_proto.op_class.name
            func_name = op_proto.op_function.name
            op_name = op_proto.op_name
            op_args = op_proto.op_args
            func_args, prim_call_args, prim_init_args = self.get_op_args(op_args)

            description = gen_utils.get_op_description(op_name, doc_dict)
            func_formal_param = ", ".join(func_args)
            op_prim_input_args = ", ".join(prim_call_args)

            # Emitted body lines use a 4-space indent so they form one suite with the description
            # block; this assumes the description is indented the same way -- TODO confirm against
            # gen_utils.get_op_description.
            if prim_init_args:
                if op_proto.op_dispatch and op_proto.op_dispatch.enable:
                    # The impl function receives every argument, init and call arguments alike.
                    func_impl_input_args = ", ".join(op_arg.arg_name for op_arg in op_args)
                    func_body = f"    return {op_name}_impl({func_impl_input_args})\n"
                else:
                    cache_prim_input_args = ", ".join(prim_init_args)
                    func_body = (
                        f"    {op_name}_op = _get_cache_prim({class_name})({cache_prim_input_args})\n"
                        f"    return {op_name}_op({op_prim_input_args})\n"
                    )
            else:
                if op_proto.op_class and op_proto.op_class.disable:
                    # The primitive instance is emitted here, ahead of the function that uses it.
                    gen_py_parts.append(f"{op_name}_op={class_name}()\n")
                func_body = f"    return {op_name}_op({op_prim_input_args})\n"

            gen_py_parts.append(f"\ndef {func_name}({func_formal_param}):\n{description}{func_body}\n")

        gen_py = "".join(gen_py_parts)
        # Drop the trailing separator appended after the last function.
        res_str = template.PY_LICENCE_STR + template.OPS_PY_DEF_HEADER + gen_py[:-len(template.NEW_LINE)]
        save_path = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
        file_name = f"{file_pre}_ops_def.py"
        gen_utils.save_file(save_path, file_name, res_str)

    def get_op_args(self, op_args):
        """
        Categorizes operator arguments into function, primitive-call and primitive-init arguments.

        Args:
            op_args (list): A list of OpArg objects representing the arguments of an operator.

        Returns:
            tuple: A tuple containing three lists:
                - func_args (list): Formal parameters of the generated function, written as ``name``
                  or ``name=default`` when the argument has a default.
                - prim_call_args (list): Names of the arguments passed when calling the primitive.
                - prim_init_args (list): Names of the arguments passed when creating the primitive.
        """
        func_args = []
        prim_init_args = []
        prim_call_args = []
        for op_arg in op_args:
            arg_name = op_arg.arg_name

            # Every argument becomes a formal parameter of the generated function.
            if op_arg.default is None:
                func_args.append(f"{arg_name}")
            else:
                func_args.append(f"{arg_name}={op_arg.default}")

            # An argument is either a primitive init argument or a primitive call argument.
            target_args = prim_init_args if op_arg.is_prim_init else prim_call_args
            target_args.append(arg_name)
        return func_args, prim_call_args, prim_init_args