mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.6.0rc1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (602) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,473 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """Op Proto module for defining operator prototypes and their arguments."""
17
+
18
+ import os
19
+ from typing import Dict
20
+
21
+ from resources.resource_loader import ResourceLoader
22
+ from resources.resource_list import ResourceType
23
+
24
+ from . import gen_constants as K
25
+ from .gen_utils import safe_load_yaml_from_dir
26
+
27
+
28
+ class OpArg:
29
+ """
30
+ Represents an argument of an operator.
31
+
32
+ Attributes:
33
+ arg_name (str): The name of the argument.
34
+ arg_dtype (str): The data type of the argument.
35
+ type_cast (list): A list of type casts applicable to the argument.
36
+ is_type_id (bool): Indicates if the argument is a type identifier.
37
+ as_init_arg (bool): Indicates if the argument is an initialization argument.
38
+ default: The default value of the argument.
39
+ inplace (str): The name of the inplace tensor if applicable.
40
+ is_prim_init (bool): Indicates if the argument is a primitive initialization argument.
41
+ arg_handler (str): A handler for the argument, if applicable.
42
+ """
43
+
44
+ def __init__(self, arg_name, arg_dtype, type_cast, is_type_id=False, as_init_arg=False, default=-1, inplace='',
45
+ is_prim_init=False, arg_handler=''):
46
+ self.arg_name = arg_name
47
+ self.arg_dtype = arg_dtype
48
+ self.type_cast = type_cast
49
+ self.is_type_id = is_type_id
50
+ self.as_init_arg = as_init_arg
51
+ self.default = default
52
+ self.inplace = inplace
53
+ self.is_prim_init = is_prim_init
54
+ self.arg_handler = arg_handler
55
+
56
+
57
+ class OpArgsSignature:
58
+ """
59
+ Represents the signature of operator arguments.
60
+
61
+ Attributes:
62
+ rw_write (list): Arguments that are written to.
63
+ rw_read (list): Arguments that are read from.
64
+ rw_ref (list): Arguments that are passed by reference.
65
+ dtype_group (list): Grouping of data types for the arguments.
66
+ """
67
+
68
+ def __init__(self, rw_write=None, rw_read=None, rw_ref=None, dtype_group=None):
69
+ self.rw_write = rw_write
70
+ self.rw_read = rw_read
71
+ self.rw_ref = rw_ref
72
+ self.dtype_group = dtype_group
73
+
74
+
75
+ class OpFunction:
76
+ """
77
+ Represents the function associated with an operator.
78
+
79
+ Attributes:
80
+ disable (bool): Indicates if the function is disabled.
81
+ name (str): The name of the function.
82
+ """
83
+
84
+ def __init__(self, disable=False, name=''):
85
+ self.disable = disable
86
+ self.name = name
87
+
88
+
89
+ class OpClass:
90
+ """
91
+ Represents a class associated with an operator.
92
+
93
+ Attributes:
94
+ disable (bool): Indicates if the class is disabled.
95
+ name (str): The name of the class.
96
+ """
97
+
98
+ def __init__(self, disable=False, name=''):
99
+ self.disable = disable
100
+ self.name = name
101
+
102
+
103
+ class OpDispatch:
104
+ """
105
+ Represents the dispatch information for an operator.
106
+
107
+ Attributes:
108
+ enable (bool): Indicates if the dispatch is enabled.
109
+ is_comm_op (bool): Indicates if the dispatch is communication operator or not.
110
+ ascend (str): The dispatch type for the Ascend device.
111
+ cpu (str): The dispatch type for the CPU.
112
+ gpu (str): The dispatch type for the GPU.
113
+ """
114
+
115
+ def __init__(self, enable=False, is_comm_op=False, ascend='default', cpu='default', gpu='default'):
116
+ self.enable = enable
117
+ self.is_comm_op = is_comm_op
118
+ self.ascend = ascend
119
+ self.cpu = cpu
120
+ self.gpu = gpu
121
+
122
+
123
+ class OpProto:
124
+ """
125
+ Defines a prototype for an operator in MindSpore.
126
+
127
+ This class is used to parse the operator definition from a YAML file and to generate
128
+ the necessary primitive and PyBoost functions.
129
+
130
+ Attributes:
131
+ op_name (str): The name of the operator.
132
+ op_args (list): A list of arguments for the operator.
133
+ op_function (OpFunction): The function associated with the operator.
134
+ op_class (OpClass): The class associated with the operator.
135
+ op_dispatch (OpDispatch): The dispatch information for the operator.
136
+ op_args_signature (OpArgsSignature): The signature of the operator's arguments.
137
+ op_returns (list): A list of return values for the operator.
138
+ op_view (bool): Indicates if the operator is a view operator.
139
+ """
140
+
141
+ def __init__(self,
142
+ op_name,
143
+ op_args,
144
+ op_function,
145
+ op_class,
146
+ op_dispatch,
147
+ op_args_signature,
148
+ op_returns,
149
+ op_view=False,
150
+ op_graph_view=False,
151
+ op_inplace=False,
152
+ op_labels=None,
153
+ op_deprecated=None,
154
+ bprop_expander=True):
155
+ self.op_name = op_name
156
+ self.op_args = op_args
157
+ self.op_function = op_function
158
+ self.op_class = op_class
159
+ self.op_dispatch = op_dispatch
160
+ self.op_args_signature = op_args_signature
161
+ self.op_returns = op_returns
162
+ self.op_view = op_view
163
+ self.op_graph_view = op_graph_view
164
+ self.op_inplace = op_inplace
165
+ self.op_labels = op_labels
166
+ self.op_deprecated = op_deprecated
167
+ self.bprop_expander = bprop_expander
168
+
169
+ @staticmethod
170
+ def load_from_yaml(op_name, op_data):
171
+ """
172
+ Loads an operator prototype from YAML data.
173
+
174
+ Args:
175
+ op_name (str): The name of the operation.
176
+ op_data (dict): A dictionary containing the operation data.
177
+
178
+ Returns:
179
+ OpProto: An instance of OpProto representing the operator.
180
+ """
181
+ # check op keys
182
+ check_validation(op_name, op_data)
183
+ # get op args
184
+ op_args = get_op_args(op_name, op_data)
185
+ # get op return args
186
+ op_returns = get_op_returns(op_name, op_data)
187
+ # get op args signature
188
+ op_args_signature = get_op_args_signature(op_name, op_data)
189
+ # get op class
190
+ op_class = get_op_class(op_name, op_data)
191
+ # get op function
192
+ op_function = get_op_function(op_name, op_data)
193
+ # get op dispatch
194
+ op_dispatch = get_op_dispatch(op_name, op_data)
195
+ # get op view
196
+ op_view = op_data.get('view', False)
197
+ if not isinstance(op_view, bool):
198
+ raise TypeError(
199
+ f'The view value should be bool, but get {type(op_view)}, op name is {op_name}.')
200
+ # get op graph view
201
+ op_graph_view = op_data.get('graph_view', False)
202
+ if not isinstance(op_graph_view, bool):
203
+ raise TypeError(
204
+ f'The graph view value should be bool, but get {type(op_graph_view)}, op name is {op_name}.')
205
+ op_inplace = is_inplace_op(op_returns)
206
+ # get op labels
207
+ op_labels = op_data.get('labels', None)
208
+ # get op deprecated
209
+ op_deprecated = op_data.get('deprecated', None)
210
+ bprop_expander = op_data.get('bprop_expander', True)
211
+ op_proto = OpProto(op_name=op_name, op_args=op_args, op_returns=op_returns, op_function=op_function,
212
+ op_class=op_class, op_dispatch=op_dispatch, op_args_signature=op_args_signature,
213
+ op_view=op_view, op_graph_view=op_graph_view, op_inplace=op_inplace, op_labels=op_labels,
214
+ op_deprecated=op_deprecated, bprop_expander=bprop_expander)
215
+ return op_proto
216
+
217
+
218
+ class OpProtoLoader(ResourceLoader):
219
+ """
220
+ OpProtoLoader is a class for loading operator prototypes from YAML data.
221
+ """
222
+ def __init__(self):
223
+ ops_yaml_path = os.path.join(K.WORK_DIR, K.MS_OP_DEF_YAML_PATH)
224
+ infer_ops_yaml_path = os.path.join(ops_yaml_path, 'infer')
225
+ self.yaml_paths = [ops_yaml_path, infer_ops_yaml_path]
226
+ self.type = ResourceType.OP_PROTO
227
+ self.is_deprecated = False
228
+ self.func_op = False
229
+
230
+ def load(self) -> Dict[ResourceType, object]:
231
+ """
232
+ Load OpProto.
233
+
234
+ Returns:
235
+ Dict[ResourceType, object]: The resource type and the OpProto.
236
+ """
237
+ yaml_dict = {}
238
+ for yaml_path in self.yaml_paths:
239
+ yaml_dict.update(safe_load_yaml_from_dir(yaml_path))
240
+ op_protos = []
241
+ for op_name, op_data in yaml_dict.items():
242
+ op_proto = OpProto.load_from_yaml(op_name, op_data)
243
+ if self.is_deprecated:
244
+ op_proto.op_name = 'deprecated_' + op_name
245
+ op_proto.func_op = self.func_op
246
+ op_protos.append(op_proto)
247
+ return {self.type: op_protos}
248
+
249
+
250
+ class DeprecatedOpProtoLoader(OpProtoLoader):
251
+ """
252
+ DeprecatedOpProtoLoader is a class for loading deprecated operator prototypes from YAML data.
253
+ """
254
+ def __init__(self):
255
+ super().__init__()
256
+ self.yaml_paths = [os.path.join(K.WORK_DIR, K.MS_OP_DEPRECATED_DEF_YAML_PATH)]
257
+ self.type = ResourceType.DEPRECATED_OP_PROTO
258
+ self.is_deprecated = True
259
+ self.func_op = True
260
+
261
+
262
+ class FuncOpProtoLoader(OpProtoLoader):
263
+ """
264
+ FuncOpProtoLoader is a class for loading func_op operator prototypes from YAML data.
265
+ """
266
+ def __init__(self):
267
+ super().__init__()
268
+ self.yaml_paths = [os.path.join(K.WORK_DIR, K.MS_OP_DEF_FUNC_OP_YAML_PATH)]
269
+ self.type = ResourceType.FUNC_OP_PROTO
270
+ self.is_deprecated = False
271
+ self.func_op = True
272
+
273
+
274
+ def get_op_args_signature(op_name, op_data):
275
+ """
276
+ Retrieves the argument signature from the operation data.
277
+
278
+ Args:
279
+ op_data (dict): A dictionary containing the operation data.
280
+
281
+ Returns:
282
+ OpArgsSignature: An instance of OpArgsSignature containing the argument signature.
283
+ """
284
+ op_args_signature = op_data.get('args_signature', None)
285
+ if op_args_signature is not None:
286
+ args_signature_keys = op_args_signature.keys()
287
+ check_op_yaml_keys(op_name, set(args_signature_keys), K.ARG_SIGNATURE_KEYS)
288
+ rw_write = op_args_signature.get('rw_write', None)
289
+ rw_read = op_args_signature.get('rw_read', None)
290
+ rw_ref = op_args_signature.get('rw_ref', None)
291
+ dtype_group = op_args_signature.get('dtype_group', None)
292
+ return OpArgsSignature(rw_write, rw_read, rw_ref, dtype_group)
293
+ return None
294
+
295
+
296
+ def check_validation(op_name: str, op_data: dict):
297
+ """
298
+ Validates the operator data to ensure it contains necessary keys.
299
+
300
+ Args:
301
+ op_data (dict): The operator data to validate.
302
+
303
+ Raises:
304
+ TypeError: If the required keys 'args' or 'returns' are missing.
305
+ """
306
+ # check keys
307
+ check_op_yaml_keys(op_name, set(op_data.keys()), K.OP_KEYS)
308
+
309
+ # Those keys must in yaml
310
+ if 'args' not in op_data.keys():
311
+ raise TypeError(f"Op define miss key 'args', op name is {op_name}")
312
+ if 'returns' not in op_data.keys():
313
+ raise TypeError(f"Op define miss key 'returns', op name is {op_name}")
314
+
315
+
316
+ def get_op_args(op_name, op_data):
317
+ """
318
+ Retrieves the arguments for the operator from the operation data.
319
+
320
+ Args:
321
+ op_data (dict): A dictionary containing the operation data.
322
+
323
+ Returns:
324
+ list: A list of OpArg instances representing the arguments of the operator.
325
+ """
326
+ args_dict = op_data.get('args')
327
+ op_args = []
328
+ for arg_name in args_dict.keys():
329
+ arg_keys = args_dict[arg_name].keys()
330
+ check_op_yaml_keys(op_name, set(arg_keys), K.ARG_KEYS)
331
+ arg_dtype = args_dict[arg_name]['dtype']
332
+ if arg_dtype == 'TypeId':
333
+ arg_dtype = 'int'
334
+ default = None
335
+ as_init_arg = False
336
+ is_type_id = False
337
+ prim_init = False
338
+ type_cast = []
339
+ if 'default' in args_dict[arg_name]:
340
+ default = args_dict[arg_name]['default']
341
+ as_init_arg = True
342
+ # 当op_args任意一个参数有prim_init,该op就要在pyboost_inner_prim.py生成
343
+ if 'prim_init' in args_dict[arg_name] and args_dict[arg_name]['prim_init'] is True:
344
+ prim_init = True
345
+ if 'type_cast' in args_dict[arg_name]:
346
+ type_cast = [cast_type.strip() for cast_type in args_dict[arg_name]['type_cast'].split(',')]
347
+ arg_handler_key = 'arg_handler'
348
+ arg_handler = args_dict[arg_name].get(arg_handler_key, '')
349
+ if arg_handler_key in args_dict[arg_name] and args_dict[arg_name][arg_handler_key] == 'dtype_to_type_id':
350
+ is_type_id = True
351
+ op_arg = OpArg(arg_name, arg_dtype, type_cast, is_type_id, as_init_arg, default,
352
+ is_prim_init=prim_init, arg_handler=arg_handler)
353
+ op_args.append(op_arg)
354
+ return op_args
355
+
356
+
357
+ def get_op_returns(op_name, op_data):
358
+ """
359
+ Retrieves the return values for the operator from the operation data.
360
+
361
+ Args:
362
+ op_data (dict): A dictionary containing the operation data.
363
+
364
+ Returns:
365
+ list: A list of OpArg instances representing the return values of the operator.
366
+ """
367
+ op_return_args = []
368
+ return_dict = op_data['returns']
369
+ for return_name in return_dict.keys():
370
+ return_keys = return_dict[return_name].keys()
371
+ check_op_yaml_keys(op_name, set(return_keys), K.RETURN_KEYS)
372
+ inplace = ''
373
+ if 'inplace' in return_dict[return_name]:
374
+ inplace = return_dict[return_name]['inplace']
375
+ if 'dtype' not in return_dict[return_name]:
376
+ raise TypeError("op return args need key 'dtype'")
377
+ dtype = return_dict[return_name]['dtype']
378
+ arg = OpArg(return_name, dtype, type_cast=[], inplace=inplace)
379
+ op_return_args.append(arg)
380
+ return op_return_args
381
+
382
+
383
+ def get_op_dispatch(op_name, op_data):
384
+ """
385
+ Retrieves the dispatch information for the operator from the operation data.
386
+
387
+ Args:
388
+ op_data (dict): A dictionary containing the operation data.
389
+
390
+ Returns:
391
+ OpDispatch: An instance of OpDispatch containing the dispatch information.
392
+ """
393
+ op_dispatch = op_data.get('dispatch', {})
394
+ dispatch_keys = op_dispatch.keys()
395
+ check_op_yaml_keys(op_name, set(dispatch_keys), K.DISPATCH_KEYS)
396
+ if not op_dispatch:
397
+ return None
398
+ enable = op_dispatch.get('enable', False)
399
+ if not isinstance(enable, bool):
400
+ raise TypeError(
401
+ f'The dispatch enable value should be bool, but get {type(enable)}, op name is {op_name}.')
402
+ is_comm_op = op_dispatch.get('is_comm_op', False)
403
+ ascend = op_dispatch.get('Ascend', 'default')
404
+ cpu = op_dispatch.get('CPU', 'default')
405
+ gpu = op_dispatch.get('GPU', 'default')
406
+ return OpDispatch(enable, is_comm_op, ascend, cpu, gpu)
407
+
408
+
409
+ def get_op_class(op_name, op_data) -> OpClass:
410
+ """
411
+ Retrieves the class information for the operator from the operation data.
412
+
413
+ Args:
414
+ op_name (str): The name of the operation.
415
+ op_data (dict): A dictionary containing the operation data.
416
+
417
+ Returns:
418
+ OpClass: An instance of OpClass containing the class information for the operator.
419
+ """
420
+ op_class = op_data.get('class', {})
421
+ class_keys = op_class.keys()
422
+ check_op_yaml_keys(op_name, set(class_keys), K.CLASS_KEYS)
423
+ is_disable = op_class.get('disable', False)
424
+ if not isinstance(is_disable, bool):
425
+ raise TypeError(
426
+ f'The class disable value should be bool, but get {type(is_disable)}, op name is {op_name}.')
427
+ class_name = op_class.get('name', convert_python_func_name_to_c(op_name))
428
+ return OpClass(disable=is_disable, name=class_name)
429
+
430
+
431
+ def get_op_function(op_name, op_data) -> OpFunction:
432
+ """
433
+ Retrieves the function information for the operator from the operation data.
434
+
435
+ Args:
436
+ op_name (str): default operation function name.
437
+ op_data (dict): A dictionary containing the operation data.
438
+
439
+ Returns:
440
+ OpFunction: An instance of OpFunction containing the function information for the operator.
441
+ """
442
+ op_function = op_data.get('function', {})
443
+ function_keys = op_function.keys()
444
+ check_op_yaml_keys(op_name, set(function_keys), K.FUNCTION_KEYS)
445
+ is_disable = op_function.get('disable', False)
446
+ if not isinstance(is_disable, bool):
447
+ raise TypeError(
448
+ f'The function disable value should be bool, but get {type(is_disable)}, op name is {op_name}.')
449
+ function_name = op_function.get('name', op_name)
450
+ return OpFunction(disable=is_disable, name=function_name)
451
+
452
+
453
+ def convert_python_func_name_to_c(func_name: str) -> str:
454
+ return ''.join(word.capitalize() for word in func_name.split('_'))
455
+
456
+
457
+ def check_op_yaml_keys(op_name: str, input_keys: set, compare_keys: set):
458
+ diff_keys = input_keys - compare_keys
459
+ if diff_keys:
460
+ raise TypeError(
461
+ f'The definition of keys in yaml has faults, op name is {op_name}, wrong keys are {diff_keys}.')
462
+
463
+
464
+ def is_inplace_op(args):
465
+ """
466
+ is inplace op
467
+ :param args:
468
+ :return: bool
469
+ """
470
+ for arg in args:
471
+ if arg.inplace:
472
+ return True
473
+ return False