mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.6.0rc1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the registry's advisory for this release for more details.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,155 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ This module defines the PyboostGradFunctionsGenerator class, which is responsible for generating PyBoost gradient
17
+ functions and saving the corresponding C++ source files. The generator parses operator prototypes and constructs
18
+ function definitions, includes necessary headers, and generates the registration code for the PyBoost functions.
19
+ """
20
+
21
+
22
+ import os
23
+
24
+ import common.template as template
25
+ from common.template import Template
26
+ from common.gen_utils import save_file
27
+ import common.gen_constants as K
28
+ from common.op_proto import OpProto
29
+ from common.base_generator import BaseGenerator
30
+ from pyboost import pyboost_utils
31
+
32
+ from .op_template_parser import OpTemplateParser
33
+
34
+
35
class PyboostGradFunctionsGenerator(BaseGenerator):
    """
    PyboostGradFunctionsGenerator generates PyBoost gradient functions based on operator prototypes (instances of
    OpProto).

    For every operator whose PyBoost dispatch is enabled, this class emits three things:
    - a C++ function definition (input value conversion, contiguous-tensor conversion and the kernel call);
    - a registration entry;
    - an `#include` line for the operator's auto-generated PyBoost header.

    All output is written to a single `pyboost_grad_functions.cc` file under `K.PYBOOST_GRAD_FUNC_GEN_PATH`.
    """

    def __init__(self):
        super().__init__()
        # `#include` line for one operator's header under the PyBoost auto_generate directory.
        self.pyboost_func_include_header_template = Template(
            f'#include "{K.MS_PYBOOST_BASE_PATH}/auto_generate/${{operator_name}}.h"\n')
        # Not referenced elsewhere in this class; retained so the public attribute keeps existing.
        self.GEN_OPS_DEF_HEADER_TEMPLATE = template.GEN_OPS_DEF_HEADER_TEMPLATE
        # Re-assigns a converted input as a contiguous tensor value for the given device target.
        self.contiguous_template = Template(
            "convert_$arg_name = runtime::ValueConverter::ContiguousTensorValue($device_target, convert_$arg_name);\n")
        # Declares `convert_<arg_name>` from the runner input at `$arg_index`.
        # Built once here instead of on every `_convert_value_type` call.
        self.convert_template = Template(
            "auto convert_$arg_name = ValueConverter::${convert_func}(op_runner_info->inputs[$arg_index]);\n")

    def generate(self, work_path, op_protos):
        """
        Generates the PyBoost gradient functions and writes them to `pyboost_grad_functions.cc`.

        Operators whose `op_dispatch` is missing or disabled are skipped. For each remaining operator, this method:
        - emits a function definition from `template.PYBOOST_GRAD_FUNCTION_TEMPLATE`;
        - emits a registration entry from `template.REGISTER_PYBOOST_GRAD_DEFINE_TEMPLATE`;
        - emits an `#include` for the operator's header.

        The pieces are then wrapped in `template.PYBOOST_GRAD_HEADER_TEMPLATE` and saved.

        Args:
            work_path (str): Root directory; the file is written under `K.PYBOOST_GRAD_FUNC_GEN_PATH` inside it.
            op_protos (list): Operator prototypes (OpProto) to process.

        Returns:
            None
        """
        func_defs = []
        reg_defs = []
        include_headers = []
        # C++ expression used in the emitted code to select the device for contiguous conversion.
        device_target = "op_runner_info->device_target"

        for op_proto in op_protos:
            if op_proto.op_dispatch is None or not op_proto.op_dispatch.enable:
                continue
            op_parser = OpTemplateParser(op_proto)
            op_pyboost_func_name = op_parser.get_pyboost_func_name()
            op_name_str = op_proto.op_class.name
            op_args_str = [op_arg.arg_name for op_arg in op_proto.op_args]
            convert_value_type_str = self._convert_value_type(op_proto)
            convert_value_type_str += self._contiguous_tensor_value(op_proto, device_target)
            # The kernel is called with the converted locals declared by the convert body.
            call_args_str = ['convert_' + op_arg.arg_name for op_arg in op_proto.op_args]

            func_defs.append(template.PYBOOST_GRAD_FUNCTION_TEMPLATE.replace(
                func_name=op_pyboost_func_name,
                op_name=op_name_str,
                op_args=op_args_str,
                convert_body=convert_value_type_str,
                call_args=call_args_str))
            func_defs.append(template.NEW_LINE)
            reg_defs.append(template.REGISTER_PYBOOST_GRAD_DEFINE_TEMPLATE.replace(
                pyboost_op_name=op_name_str,
                pyboost_cfunc_name=op_pyboost_func_name))
            include_headers.append(self.pyboost_func_include_header_template.replace(
                operator_name=op_proto.op_name))

        register_func_str = template.REGISTER_PYBOOST_GRAD_TEMPLATE.replace(register_func=''.join(reg_defs))
        pyboost_func_file = template.PYBOOST_GRAD_HEADER_TEMPLATE.replace(
            include_op_header=''.join(include_headers),
            function_body=''.join(func_defs),
            register_function_body=register_func_str)
        save_path = os.path.join(work_path, K.PYBOOST_GRAD_FUNC_GEN_PATH)
        file_name = "pyboost_grad_functions.cc"
        save_file(save_path, file_name, pyboost_func_file)

    def _convert_value_type(self, op_proto: OpProto) -> str:
        """
        Generates the code that converts each operator input into a typed local `convert_<arg_name>`.

        For every argument, the conversion function name is obtained from
        `pyboost_utils.get_value_convert_type_str`, based on the argument dtype and on whether
        `pyboost_utils.is_optional_param` reports it as optional. The input slot expression comes from
        `pyboost_utils.get_index` applied to the argument's position.

        Args:
            op_proto (OpProto): The operator prototype containing the operator's arguments.

        Returns:
            str: One conversion statement per argument, concatenated in argument order.
        """
        return ''.join(
            self.convert_template.replace(
                arg_name=arg.arg_name,
                convert_func=pyboost_utils.get_value_convert_type_str(arg.arg_dtype,
                                                                      pyboost_utils.is_optional_param(arg)),
                arg_index=pyboost_utils.get_index(index))
            for index, arg in enumerate(op_proto.op_args))

    def _contiguous_tensor_value(self, op_proto: OpProto, device_target: str) -> str:
        """
        Generates the code that converts tensor inputs to contiguous tensor values.

        Only arguments whose dtype is `tensor` or `tuple[tensor]` get a conversion statement. View operators
        (`op_proto.op_view` truthy) get no conversion at all.

        Args:
            op_proto (OpProto): The operator prototype containing the operator's arguments.
            device_target (str): C++ expression naming the device target, inserted into the emitted code.

        Returns:
            str: The concatenated conversion statements, or '' for view operators.
        """
        if op_proto.op_view:
            return ''
        need_contiguous_dtype = {'tensor', 'tuple[tensor]'}
        return ''.join(
            self.contiguous_template.replace(arg_name=arg.arg_name, device_target=device_target)
            for arg in op_proto.op_args
            if arg.arg_dtype in need_contiguous_dtype)
@@ -0,0 +1,132 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ This module defines the PyboostInnerPrimGenerator class, which is responsible for generating Python primitive
17
+ wrappers for Pyboost operations. The generator constructs Python function definitions based on operator prototypes,
18
+ generates necessary import statements, and writes the generated content into Python source files.
19
+
20
+ The primary functionality is to take operator prototypes, extract relevant fields, and create Python function wrappers
21
+ that can be used to call the Pyboost primitive implementations.
22
+ """
23
+
24
+ import os
25
+ import common.template as template
26
+ import common.gen_constants as K
27
+ from common.gen_utils import save_file
28
+
29
+ from common.op_proto import OpProto
30
+ from common.base_generator import BaseGenerator
31
+
32
+ from .op_template_parser import OpTemplateParser
33
+
34
+
35
class PyboostInnerPrimGenerator(BaseGenerator):
    """
    PyboostInnerPrimGenerator is responsible for generating Python primitive wrappers for Pyboost operators.

    This class processes operator prototypes (`op_protos`) to generate Python function implementations. For each
    operator it emits an import line and a wrapper that applies the argument handlers before calling the primitive.
    Only operators with PyBoost dispatch enabled and a primitive `__init__` are processed.

    Attributes:
        IMPORT_PYBOOST_PRIM_HEADER: Import header placed at the top of the generated file.
        PYBOOST_PY_FUNC_IMPORT_HEADER (Template): Template for importing Python functions related to Pyboost.
        PYTHON_PRIM_TEMPLATE (Template): Template for generating Python primitive functions.
    """

    def __init__(self):
        """
        Initializes the PyboostInnerPrimGenerator class.

        This constructor sets up the required templates for generating import headers, Python function imports,
        and Python primitive function wrappers.
        """
        # Initialize the base class, as the sibling generators built on BaseGenerator do.
        super().__init__()
        self.IMPORT_PYBOOST_PRIM_HEADER = template.IMPORT_PYBOOST_PRIM_HEADER
        # The template module spells this constant `..._HEADEAR`; the attribute uses the corrected spelling.
        self.PYBOOST_PY_FUNC_IMPORT_HEADER = template.PYBOOST_PY_FUNC_IMPORT_HEADEAR
        self.PYTHON_PRIM_TEMPLATE = template.PYTHON_PRIM_TEMPLATE

    def generate(self, work_path, op_protos):
        """
        Generates Python wrappers for Pyboost primitives and writes them to `pyboost_inner_prim.py`.

        Operators are skipped when:
        - `op_dispatch` is missing or disabled; or
        - `OpTemplateParser.has_prim_init()` is false.

        Each remaining operator contributes one import line and one wrapper built from `PYTHON_PRIM_TEMPLATE`.

        Args:
            work_path (str): Root directory; the file is written under `K.PY_AUTO_GEN_PATH` inside it.
            op_protos (list): Operator prototypes (OpProto) to process.

        Returns:
            None
        """
        gen_py = ''
        gen_header = template.PY_LICENSE_STR + self.IMPORT_PYBOOST_PRIM_HEADER
        for op_proto in op_protos:
            # Only process the pyboost-enabled scenario.
            if op_proto.op_dispatch is None or not op_proto.op_dispatch.enable:
                continue
            op_parser = OpTemplateParser(op_proto)
            if not op_parser.has_prim_init():
                continue

            class_name = op_proto.op_class.name
            gen_header += self.PYBOOST_PY_FUNC_IMPORT_HEADER.replace(class_name=class_name)
            input_args, process_func, processed_args = self._get_fields_for_prim_tpl(op_proto)
            gen_py += self.PYTHON_PRIM_TEMPLATE.replace(class_name=class_name,
                                                        input_args=input_args,
                                                        process_func=process_func,
                                                        func_impl_name=op_proto.op_name,
                                                        processed_args=processed_args)

        save_file(os.path.join(work_path, K.PY_AUTO_GEN_PATH), "pyboost_inner_prim.py", gen_header + gen_py)

    def _get_fields_for_prim_tpl(self, op_proto: OpProto):
        """
        Extracts the fields needed by the primitive template from the operator prototype.

        Handler application depends on the argument's `arg_handler`:
        - Non-empty and not `dtype_to_type_id`: the generated code calls
          `<handler>('<op_name>', '<arg_name>', <arg_name>)`, and the wrapper passes `converted_<arg_name>` on.
        - Empty, None, or `dtype_to_type_id`: the argument is passed through unchanged.

        Args:
            op_proto (OpProto): The operator prototype from which the argument data will be extracted.

        Returns:
            tuple: A tuple containing three return values required for the primitive template to be generated:
                - input_args (list): List of input argument names for the Python function.
                - process_func (str): Generated handler-call lines, one per handled argument.
                - processed_args (list): List of processed argument names used in the function call.
        """
        operator_name = op_proto.op_name

        input_args = []
        process_lines = []
        processed_args = []

        for arg in op_proto.op_args:
            arg_name = arg.arg_name
            arg_handler = arg.arg_handler
            input_args.append(arg_name)
            # A falsy handler (not only '') means "no handler".
            # Without this check, `None` would be emitted as `converted_x = None(...)`.
            if arg_handler and arg_handler != 'dtype_to_type_id':
                process_lines.append(
                    f"""converted_{arg_name} = {arg_handler}('{operator_name}', '{arg_name}', {arg_name})\n""")
                processed_args.append('converted_' + arg_name)
            else:
                processed_args.append(arg_name)

        return input_args, ''.join(process_lines), processed_args
@@ -0,0 +1,272 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ This module defines the PyboostGradFunctionsCppGenerator and PyboostGradFunctionsHeaderGenerator classes,
17
+ which are responsible for generating C++ gradient function implementations and headers for PyBoost operations.
18
+
19
+ The PyboostGradFunctionsCppGenerator generates the actual function definitions for the gradient functions,
20
+ while the PyboostGradFunctionsHeaderGenerator creates the corresponding function declarations in header files.
21
+ """
22
+
23
+ import os
24
+
25
+ from pyboost import pyboost_utils
26
+ from pyboost.pyboost_utils import is_optional_param
27
+ import common.template as template
28
+ from common.template import Template
29
+ import common.gen_constants as K
30
+ from common.gen_utils import save_file
31
+ from common.op_proto import OpProto
32
+ from common.base_generator import BaseGenerator
33
+
34
+
35
class PyboostGradFunctionsCppGenerator(BaseGenerator):
    """
    PyboostGradFunctionsCppGenerator generates C++ implementations for PyBoost gradient functions.

    This class processes operator prototypes (`op_protos`) to create C++ function definitions that
    wrap PyBoost gradient functionality. It constructs function bodies, handles value conversion,
    and generates necessary include statements for each operator.

    Attributes:
        PYBOOST_NATIVE_GRAD_FUNCTION_TEMPLATE (Template): Template for generating native gradient function definitions.
        PYBOOST_NATIVE_GRAD_FUNCTIONS_TEMPLATE (Template): Template for generating the overall gradient functions file.
        native_function_multi_output_template (Template): Template for handling multiple output functions.
        native_function_single_output_template (str): Template for handling single output functions.
        native_include_header_template (Template): Template for generating include headers for each operator.
        convert_template (Template): Template for converting argument values to native types.
        contiguous_template (Template): Template for converting a converted tensor value to a contiguous one.
    """

    def __init__(self):
        # Initialize the base class, as the sibling generators built on BaseGenerator do.
        super().__init__()
        self.PYBOOST_NATIVE_GRAD_FUNCTION_TEMPLATE = template.PYBOOST_NATIVE_GRAD_FUNCTION_TEMPLATE
        self.PYBOOST_NATIVE_GRAD_FUNCTIONS_TEMPLATE = template.PYBOOST_NATIVE_GRAD_FUNCTIONS_TEMPLATE
        self.native_function_multi_output_template = template.MULTI_OUTPUT_TEMPLATE
        self.native_function_single_output_template = "const auto &output_value = op->outputs()[0];\n"
        self.native_include_header_template = Template(
            f'#include "{K.MS_PYBOOST_BASE_PATH}/auto_generate/${{operator_name}}.h"\n')
        self.convert_template = Template(
            "auto convert_$arg_name = runtime::ValueConverter::${convert_func}(ConvertNode2Value($arg_name));\n")
        # Built once here rather than on every `_contiguous_tensor_value` call.
        self.contiguous_template = Template(
            "convert_$arg_name = runtime::ValueConverter::ContiguousTensorValue($device_target, convert_$arg_name);\n")

    def generate(self, work_path, op_protos):
        """
        Generates C++ gradient function implementations and writes them to `pyboost_native_grad_functions.cc`.

        Operators whose `op_dispatch` is missing or disabled are skipped. For each remaining operator, this method:
        - emits a function definition;
        - emits an `#include` for its PyBoost header;
        - records a deduplicated op-def include keyed by the first letter of the operator class name.

        Args:
            work_path (str): Root directory; the file is written under `K.PYBOOST_NATIVE_GRAD_FUNC_GEN_PATH` inside it.
            op_protos (list): Operator prototypes (OpProto) to process.

        Returns:
            None
        """
        func_defs = []
        include_headers = []
        ops_inc_head_set = set()
        for op_proto in op_protos:
            if op_proto.op_dispatch is None or not op_proto.op_dispatch.enable:
                continue

            op_class_name = op_proto.op_class.name
            op_args_str = [op_arg.arg_name for op_arg in op_proto.op_args]
            convert_value_type_str = self._convert_native_value_type(op_proto)
            # NOTE(review): `device_target_` is emitted verbatim into the C++ code.
            # Presumably it is in scope in the generated function; TODO confirm against the template.
            convert_value_type_str += self._contiguous_tensor_value(op_proto, "device_target_")
            call_args_str = self._get_call_args(op_proto)
            call_args_with_type = self._get_call_args_with_type(op_proto)

            # Assumes the operator has at least one argument; an argument-less operator raises IndexError here.
            first_var_name = op_proto.op_args[0].arg_name
            output_expr = self._get_output_expr(op_proto)

            func_defs.append(self.PYBOOST_NATIVE_GRAD_FUNCTION_TEMPLATE.replace(
                func_name=op_class_name,
                op_name=op_class_name,
                op_args=op_args_str,
                convert_body=convert_value_type_str,
                call_args=call_args_str,
                call_args_with_type=call_args_with_type,
                first_var_name=first_var_name,
                output_expr=output_expr))
            func_defs.append(template.NEW_LINE)
            include_headers.append(self.native_include_header_template.replace(operator_name=op_proto.op_name))
            ops_inc_head_set.add(
                template.OP_DEF_INC_HEAD_TEMPLATE.replace(prefix_char=op_class_name[0].lower()))

        native_grad_func_file = self.PYBOOST_NATIVE_GRAD_FUNCTIONS_TEMPLATE.replace(
            include_op_header=''.join(include_headers),
            function_body=''.join(func_defs),
            # Sorted so the generated file is deterministic regardless of set ordering.
            ops_inc=sorted(ops_inc_head_set))
        save_file(os.path.join(work_path, K.PYBOOST_NATIVE_GRAD_FUNC_GEN_PATH),
                  "pyboost_native_grad_functions.cc", native_grad_func_file)

    def _convert_native_value_type(self, op_proto: OpProto) -> str:
        """
        Generates native value conversion statements for operator arguments.

        Each argument gets a `convert_<arg_name>` declaration built from `ConvertNode2Value(<arg_name>)`. The
        conversion function name comes from `pyboost_utils.get_value_convert_type_str`, based on the dtype and
        optionality.

        Args:
            op_proto (OpProto): The operator prototype from which the argument data will be extracted.

        Returns:
            str: A string containing the conversion statements for the operator's arguments.
        """
        return ''.join(
            self.convert_template.replace(
                arg_name=op_arg.arg_name,
                convert_func=pyboost_utils.get_value_convert_type_str(op_arg.arg_dtype, is_optional_param(op_arg)))
            for op_arg in op_proto.op_args)

    def _contiguous_tensor_value(self, op_proto: OpProto, device_target: str) -> str:
        """
        Generates contiguous tensor value conversion statements if applicable.

        Only arguments whose dtype is `tensor` or `tuple[tensor]` are converted. If the operator is a view
        operation, no conversion is performed.

        Args:
            op_proto (OpProto): The operator prototype that contains the argument data.
            device_target (str): The device target expression to be used in the conversion statements.

        Returns:
            str: A string containing the contiguous tensor conversion statements ('' for view operators).
        """
        if op_proto.op_view:
            return ''
        need_contiguous_dtype = {'tensor', 'tuple[tensor]'}
        return ''.join(
            self.contiguous_template.replace(arg_name=op_arg.arg_name, device_target=device_target)
            for op_arg in op_proto.op_args
            if op_arg.arg_dtype in need_contiguous_dtype)

    def _get_output_expr(self, op_proto: OpProto):
        """
        Determines the output expression based on the operator prototype.

        Args:
            op_proto (OpProto): The operator prototype to evaluate.

        Returns:
            The multi-output template if `pyboost_utils.is_op_multi_output` reports multiple returns,
            otherwise the single-output statement string.
        """
        if pyboost_utils.is_op_multi_output(op_proto.op_returns):
            return self.native_function_multi_output_template
        return self.native_function_single_output_template

    def _get_call_args(self, op_proto: OpProto):
        """
        Generates the list of call arguments for the operator function.

        Args:
            op_proto (OpProto): The operator prototype containing the argument information.

        Returns:
            list: Argument names prefixed with 'convert_', in argument order.
        """
        return ['convert_' + op_arg.arg_name for op_arg in op_proto.op_args]

    def _get_call_args_with_type(self, op_proto: OpProto):
        """
        Generates the list of call arguments with type information.

        Args:
            op_proto (OpProto): The operator prototype containing the argument information.

        Returns:
            list: One `const NodePtr &<arg_name>` declaration per argument, in argument order.
        """
        return ['const NodePtr &' + op_arg.arg_name for op_arg in op_proto.op_args]
213
+
214
+
215
class PyboostGradFunctionsHeaderGenerator(BaseGenerator):
    """
    PyboostGradFunctionsHeaderGenerator generates C++ header declarations for PyBoost gradient functions.

    This class processes operator prototypes to create function declarations for the grad functions in a header file.
    """

    def __init__(self):
        # Initialize the base class, as the sibling generators built on BaseGenerator do.
        super().__init__()
        # One `static NodePtr <Op>(const NodePtr &a, ...);` declaration per operator.
        self.native_function_header_template = Template("static NodePtr $func_name(${call_args_with_type});\n")

    def generate(self, work_path, op_protos):
        """
        Generates C++ header declarations for gradient functions and writes them to `pyboost_native_grad_functions.h`.

        Operators whose `op_dispatch` is missing or disabled are skipped. Each remaining operator contributes one
        declaration named after its operator class.

        Args:
            work_path (str): Root directory; the file is written under `K.PYBOOST_NATIVE_GRAD_FUNC_GEN_PATH` inside it.
            op_protos (list): Operator prototypes (OpProto) to process.

        Returns:
            None
        """
        func_headers = [
            self.native_function_header_template.replace(
                func_name=op_proto.op_class.name,
                call_args_with_type=self._get_call_args_with_type(op_proto))
            for op_proto in op_protos
            if op_proto.op_dispatch is not None and op_proto.op_dispatch.enable
        ]
        native_grad_func_header_file = template.PYBOOST_NATIVE_GRAD_FUNCTIONS_HEADER_TEMPLATE.replace(
            native_grad_func_def=''.join(func_headers))

        save_file(os.path.join(work_path, K.PYBOOST_NATIVE_GRAD_FUNC_GEN_PATH),
                  "pyboost_native_grad_functions.h", native_grad_func_header_file)

    def _get_call_args_with_type(self, op_proto: OpProto):
        """
        Generates the list of call arguments with type information for the header declarations.

        Args:
            op_proto (OpProto): The operator prototype containing the argument information.

        Returns:
            list: One `const NodePtr &<arg_name>` declaration per argument, in argument order.
        """
        return ['const NodePtr &' + op_arg.arg_name for op_arg in op_proto.op_args]