mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (579) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +13 -6
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -0
  7. mindspore/_checkparam.py +3 -38
  8. mindspore/_deprecated/__init__.py +17 -0
  9. mindspore/_deprecated/jit.py +198 -0
  10. mindspore/_extends/builtin_operations.py +1 -1
  11. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  12. mindspore/_extends/parse/__init__.py +6 -7
  13. mindspore/_extends/parse/compile_config.py +83 -0
  14. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +47 -198
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +229 -99
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +11 -5
  27. mindspore/avcodec-59.dll +0 -0
  28. mindspore/avdevice-59.dll +0 -0
  29. mindspore/avfilter-8.dll +0 -0
  30. mindspore/avformat-59.dll +0 -0
  31. mindspore/avutil-57.dll +0 -0
  32. mindspore/boost/__init__.py +2 -2
  33. mindspore/boost/base.py +3 -7
  34. mindspore/boost/boost_cell_wrapper.py +138 -43
  35. mindspore/common/__init__.py +6 -3
  36. mindspore/common/_grad_function.py +56 -0
  37. mindspore/common/_pijit_context.py +14 -5
  38. mindspore/common/_register_for_tensor.py +1 -2
  39. mindspore/common/_stub_tensor.py +30 -14
  40. mindspore/common/_tensor_cpp_method.py +17 -0
  41. mindspore/common/_tensor_docs.py +4760 -0
  42. mindspore/common/api.py +480 -372
  43. mindspore/common/auto_dynamic_shape.py +41 -44
  44. mindspore/common/dtype.py +39 -36
  45. mindspore/common/dump.py +9 -6
  46. mindspore/common/file_system.py +9 -1
  47. mindspore/common/generator.py +5 -0
  48. mindspore/common/hook_handle.py +6 -2
  49. mindspore/common/initializer.py +13 -10
  50. mindspore/common/jit_begin_end.py +94 -0
  51. mindspore/common/jit_config.py +6 -1
  52. mindspore/common/jit_context.py +76 -0
  53. mindspore/common/jit_trace.py +378 -0
  54. mindspore/common/lazy_inline.py +9 -3
  55. mindspore/common/mindir_util.py +10 -2
  56. mindspore/common/mutable.py +5 -4
  57. mindspore/common/parameter.py +135 -52
  58. mindspore/common/seed.py +2 -2
  59. mindspore/common/sparse_tensor.py +23 -17
  60. mindspore/common/tensor.py +975 -1981
  61. mindspore/communication/__init__.py +7 -5
  62. mindspore/communication/_comm_helper.py +52 -2
  63. mindspore/communication/comm_func.py +240 -181
  64. mindspore/communication/management.py +95 -26
  65. mindspore/context.py +324 -573
  66. mindspore/dataset/__init__.py +65 -37
  67. mindspore/dataset/audio/__init__.py +2 -8
  68. mindspore/dataset/audio/transforms.py +3 -17
  69. mindspore/dataset/callback/ds_callback.py +2 -1
  70. mindspore/dataset/core/config.py +87 -6
  71. mindspore/dataset/engine/cache_admin.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +6 -5
  73. mindspore/dataset/engine/datasets.py +292 -267
  74. mindspore/dataset/engine/datasets_audio.py +22 -8
  75. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  76. mindspore/dataset/engine/datasets_text.py +78 -48
  77. mindspore/dataset/engine/datasets_user_defined.py +183 -117
  78. mindspore/dataset/engine/datasets_vision.py +120 -44
  79. mindspore/dataset/engine/iterators.py +283 -63
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/obs/util.py +8 -0
  82. mindspore/dataset/engine/queue.py +40 -0
  83. mindspore/dataset/engine/samplers.py +289 -43
  84. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  85. mindspore/dataset/engine/validators.py +53 -11
  86. mindspore/dataset/text/__init__.py +7 -6
  87. mindspore/dataset/text/transforms.py +6 -5
  88. mindspore/dataset/text/utils.py +3 -3
  89. mindspore/dataset/transforms/__init__.py +0 -9
  90. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  91. mindspore/dataset/transforms/transforms.py +31 -14
  92. mindspore/dataset/utils/browse_dataset.py +1 -1
  93. mindspore/dataset/vision/__init__.py +2 -9
  94. mindspore/dataset/vision/transforms.py +202 -158
  95. mindspore/dataset/vision/utils.py +7 -5
  96. mindspore/dataset/vision/validators.py +1 -2
  97. mindspore/device_context/__init__.py +21 -0
  98. mindspore/device_context/ascend/__init__.py +25 -0
  99. mindspore/device_context/ascend/device.py +72 -0
  100. mindspore/device_context/ascend/op_debug.py +153 -0
  101. mindspore/device_context/ascend/op_precision.py +193 -0
  102. mindspore/device_context/ascend/op_tuning.py +123 -0
  103. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  104. mindspore/device_context/cpu/device.py +62 -0
  105. mindspore/device_context/cpu/op_tuning.py +43 -0
  106. mindspore/device_context/gpu/__init__.py +21 -0
  107. mindspore/device_context/gpu/device.py +70 -0
  108. mindspore/device_context/gpu/op_precision.py +67 -0
  109. mindspore/device_context/gpu/op_tuning.py +175 -0
  110. mindspore/device_manager.py +170 -0
  111. mindspore/dnnl.dll +0 -0
  112. mindspore/experimental/es/embedding_service.py +35 -27
  113. mindspore/experimental/llm_boost/__init__.py +1 -0
  114. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  115. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
  116. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  117. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  118. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  119. mindspore/experimental/llm_boost/register.py +1 -0
  120. mindspore/experimental/map_parameter.py +4 -4
  121. mindspore/experimental/optim/adadelta.py +6 -6
  122. mindspore/experimental/optim/adagrad.py +4 -4
  123. mindspore/experimental/optim/adam.py +7 -0
  124. mindspore/experimental/optim/adamax.py +4 -4
  125. mindspore/experimental/optim/adamw.py +4 -0
  126. mindspore/experimental/optim/asgd.py +1 -1
  127. mindspore/experimental/optim/lr_scheduler.py +73 -46
  128. mindspore/experimental/optim/radam.py +34 -31
  129. mindspore/experimental/optim/rprop.py +1 -1
  130. mindspore/experimental/optim/sgd.py +1 -1
  131. mindspore/hal/contiguous_tensors_handle.py +6 -10
  132. mindspore/hal/device.py +55 -53
  133. mindspore/hal/event.py +52 -52
  134. mindspore/hal/memory.py +179 -120
  135. mindspore/hal/stream.py +150 -109
  136. mindspore/include/api/context.h +0 -1
  137. mindspore/include/dataset/constants.h +7 -4
  138. mindspore/include/dataset/execute.h +2 -2
  139. mindspore/jpeg62.dll +0 -0
  140. mindspore/log.py +50 -0
  141. mindspore/mindrecord/__init__.py +21 -8
  142. mindspore/mindrecord/config.py +17 -316
  143. mindspore/mindrecord/filereader.py +1 -9
  144. mindspore/mindrecord/filewriter.py +5 -15
  145. mindspore/mindrecord/mindpage.py +1 -9
  146. mindspore/mindspore_backend_common.dll +0 -0
  147. mindspore/mindspore_backend_manager.dll +0 -0
  148. mindspore/mindspore_common.dll +0 -0
  149. mindspore/mindspore_core.dll +0 -0
  150. mindspore/mindspore_dump.dll +0 -0
  151. mindspore/mindspore_frontend.dll +0 -0
  152. mindspore/mindspore_glog.dll +0 -0
  153. mindspore/mindspore_memory_pool.dll +0 -0
  154. mindspore/mindspore_ms_backend.dll +0 -0
  155. mindspore/mindspore_ops.dll +0 -0
  156. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  157. mindspore/mindspore_ops_kernel_common.dll +0 -0
  158. mindspore/mindspore_profiler.dll +0 -0
  159. mindspore/mindspore_pyboost.dll +0 -0
  160. mindspore/mindspore_pynative.dll +0 -0
  161. mindspore/mindspore_res_manager.dll +0 -0
  162. mindspore/mindspore_runtime_pipeline.dll +0 -0
  163. mindspore/mint/__init__.py +798 -761
  164. mindspore/mint/distributed/__init__.py +70 -4
  165. mindspore/mint/distributed/distributed.py +2679 -44
  166. mindspore/mint/linalg/__init__.py +8 -0
  167. mindspore/mint/nn/__init__.py +743 -22
  168. mindspore/mint/nn/functional.py +716 -23
  169. mindspore/mint/nn/layer/__init__.py +21 -4
  170. mindspore/mint/nn/layer/_functions.py +334 -0
  171. mindspore/mint/nn/layer/activation.py +276 -1
  172. mindspore/mint/nn/layer/basic.py +123 -0
  173. mindspore/mint/nn/layer/conv.py +933 -0
  174. mindspore/mint/nn/layer/normalization.py +223 -28
  175. mindspore/mint/nn/layer/padding.py +797 -0
  176. mindspore/mint/nn/layer/pooling.py +235 -0
  177. mindspore/mint/optim/__init__.py +3 -1
  178. mindspore/mint/optim/adam.py +223 -0
  179. mindspore/mint/optim/adamw.py +26 -19
  180. mindspore/mint/optim/sgd.py +171 -0
  181. mindspore/mint/special/__init__.py +2 -1
  182. mindspore/multiprocessing/__init__.py +5 -0
  183. mindspore/nn/__init__.py +4 -1
  184. mindspore/nn/cell.py +1373 -192
  185. mindspore/nn/dynamic_lr.py +2 -1
  186. mindspore/nn/layer/activation.py +29 -27
  187. mindspore/nn/layer/basic.py +51 -35
  188. mindspore/nn/layer/channel_shuffle.py +3 -3
  189. mindspore/nn/layer/container.py +1 -1
  190. mindspore/nn/layer/conv.py +53 -42
  191. mindspore/nn/layer/embedding.py +12 -11
  192. mindspore/nn/layer/normalization.py +56 -49
  193. mindspore/nn/layer/padding.py +4 -3
  194. mindspore/nn/layer/pooling.py +120 -42
  195. mindspore/nn/layer/rnn_cells.py +1 -1
  196. mindspore/nn/layer/rnns.py +2 -1
  197. mindspore/nn/layer/timedistributed.py +5 -5
  198. mindspore/nn/layer/transformer.py +59 -36
  199. mindspore/nn/learning_rate_schedule.py +8 -4
  200. mindspore/nn/loss/loss.py +58 -55
  201. mindspore/nn/optim/ada_grad.py +7 -5
  202. mindspore/nn/optim/adadelta.py +11 -9
  203. mindspore/nn/optim/adafactor.py +1 -1
  204. mindspore/nn/optim/adam.py +19 -15
  205. mindspore/nn/optim/adamax.py +8 -7
  206. mindspore/nn/optim/adasum.py +5 -5
  207. mindspore/nn/optim/asgd.py +3 -1
  208. mindspore/nn/optim/ftrl.py +11 -9
  209. mindspore/nn/optim/lamb.py +1 -1
  210. mindspore/nn/optim/lars.py +1 -4
  211. mindspore/nn/optim/lazyadam.py +12 -10
  212. mindspore/nn/optim/momentum.py +7 -6
  213. mindspore/nn/optim/optimizer.py +3 -3
  214. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  215. mindspore/nn/optim/rmsprop.py +13 -12
  216. mindspore/nn/optim/rprop.py +11 -9
  217. mindspore/nn/optim/sgd.py +9 -6
  218. mindspore/nn/optim/tft_wrapper.py +5 -2
  219. mindspore/nn/optim/thor.py +2 -1
  220. mindspore/nn/probability/bijector/bijector.py +17 -11
  221. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  222. mindspore/nn/probability/bijector/invert.py +2 -2
  223. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  224. mindspore/nn/probability/bijector/softplus.py +3 -2
  225. mindspore/nn/probability/distribution/beta.py +3 -3
  226. mindspore/nn/probability/distribution/categorical.py +1 -1
  227. mindspore/nn/probability/distribution/cauchy.py +4 -2
  228. mindspore/nn/probability/distribution/exponential.py +6 -7
  229. mindspore/nn/probability/distribution/gamma.py +2 -2
  230. mindspore/nn/probability/distribution/gumbel.py +2 -2
  231. mindspore/nn/probability/distribution/half_normal.py +5 -3
  232. mindspore/nn/probability/distribution/logistic.py +5 -3
  233. mindspore/nn/probability/distribution/poisson.py +1 -1
  234. mindspore/nn/probability/distribution/uniform.py +5 -3
  235. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  236. mindspore/nn/reinforcement/tensor_array.py +1 -1
  237. mindspore/nn/utils/init.py +13 -11
  238. mindspore/nn/wrap/__init__.py +6 -6
  239. mindspore/nn/wrap/cell_wrapper.py +181 -122
  240. mindspore/nn/wrap/grad_reducer.py +45 -36
  241. mindspore/nn/wrap/loss_scale.py +6 -7
  242. mindspore/numpy/array_creations.py +63 -65
  243. mindspore/numpy/array_ops.py +149 -144
  244. mindspore/numpy/logic_ops.py +41 -42
  245. mindspore/numpy/math_ops.py +361 -359
  246. mindspore/numpy/utils.py +17 -18
  247. mindspore/numpy/utils_const.py +5 -6
  248. mindspore/opencv_core452.dll +0 -0
  249. mindspore/opencv_imgcodecs452.dll +0 -0
  250. mindspore/opencv_imgproc452.dll +0 -0
  251. mindspore/ops/__init__.py +5 -3
  252. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  253. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  256. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  257. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  258. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  259. mindspore/ops/_register_for_op.py +0 -11
  260. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  261. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  262. mindspore/ops/_vmap/vmap_array_ops.py +52 -25
  263. mindspore/ops/_vmap/vmap_base.py +0 -2
  264. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  265. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  266. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  267. mindspore/ops/auto_generate/__init__.py +4 -3
  268. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
  269. mindspore/ops/auto_generate/gen_extend_func.py +757 -185
  270. mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
  271. mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
  272. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  273. mindspore/ops/composite/__init__.py +2 -1
  274. mindspore/ops/composite/base.py +20 -25
  275. mindspore/ops/composite/math_ops.py +6 -16
  276. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  277. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  278. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  279. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  280. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  281. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  282. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  283. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  284. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  285. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  286. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  287. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  288. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  289. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  290. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  291. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  292. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  293. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  294. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  295. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  299. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  301. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  302. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  303. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  304. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  305. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  306. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  307. mindspore/ops/function/__init__.py +40 -2
  308. mindspore/ops/function/_add_attr_func.py +58 -0
  309. mindspore/ops/function/array_func.py +2089 -2403
  310. mindspore/ops/function/clip_func.py +80 -23
  311. mindspore/ops/function/debug_func.py +57 -57
  312. mindspore/ops/function/grad/__init__.py +1 -0
  313. mindspore/ops/function/grad/grad_func.py +104 -71
  314. mindspore/ops/function/image_func.py +2 -2
  315. mindspore/ops/function/linalg_func.py +47 -78
  316. mindspore/ops/function/math_func.py +4351 -3813
  317. mindspore/ops/function/nn_func.py +1712 -637
  318. mindspore/ops/function/other_func.py +159 -1
  319. mindspore/ops/function/parameter_func.py +18 -84
  320. mindspore/ops/function/random_func.py +452 -387
  321. mindspore/ops/function/reshard_func.py +4 -70
  322. mindspore/ops/function/sparse_func.py +3 -3
  323. mindspore/ops/function/sparse_unary_func.py +6 -6
  324. mindspore/ops/function/spectral_func.py +25 -58
  325. mindspore/ops/function/vmap_func.py +26 -18
  326. mindspore/ops/functional.py +23 -7
  327. mindspore/ops/functional_overload.py +1548 -0
  328. mindspore/ops/op_info_register.py +32 -244
  329. mindspore/ops/operations/__init__.py +23 -15
  330. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  331. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  332. mindspore/ops/operations/_grad_ops.py +2 -43
  333. mindspore/ops/operations/_infer_ops.py +2 -1
  334. mindspore/ops/operations/_inner_ops.py +43 -84
  335. mindspore/ops/operations/_ms_kernel.py +4 -10
  336. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  337. mindspore/ops/operations/_scalar_ops.py +3 -2
  338. mindspore/ops/operations/_sequence_ops.py +1 -1
  339. mindspore/ops/operations/_tensor_array.py +1 -1
  340. mindspore/ops/operations/array_ops.py +81 -324
  341. mindspore/ops/operations/comm_ops.py +154 -108
  342. mindspore/ops/operations/custom_ops.py +298 -87
  343. mindspore/ops/operations/debug_ops.py +157 -59
  344. mindspore/ops/operations/inner_ops.py +7 -5
  345. mindspore/ops/operations/linalg_ops.py +1 -57
  346. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  347. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  348. mindspore/ops/operations/math_ops.py +32 -234
  349. mindspore/ops/operations/nn_ops.py +212 -531
  350. mindspore/ops/operations/other_ops.py +62 -9
  351. mindspore/ops/operations/random_ops.py +13 -7
  352. mindspore/ops/operations/reshard_ops.py +1 -1
  353. mindspore/ops/operations/sparse_ops.py +2 -2
  354. mindspore/ops/primitive.py +66 -53
  355. mindspore/ops/tensor_method.py +1895 -0
  356. mindspore/ops_generate/__init__.py +0 -5
  357. mindspore/ops_generate/aclnn/__init__.py +0 -0
  358. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  359. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  360. mindspore/ops_generate/api/__init__.py +0 -0
  361. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  362. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  363. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  364. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  365. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  366. mindspore/ops_generate/api/gen_api.py +103 -0
  367. mindspore/ops_generate/api/op_api_proto.py +235 -0
  368. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  369. mindspore/ops_generate/common/__init__.py +0 -0
  370. mindspore/ops_generate/common/base_generator.py +11 -0
  371. mindspore/ops_generate/common/gen_constants.py +91 -0
  372. mindspore/ops_generate/common/gen_utils.py +348 -0
  373. mindspore/ops_generate/common/op_proto.py +473 -0
  374. mindspore/ops_generate/common/template.py +523 -0
  375. mindspore/ops_generate/gen_ops.py +22 -1069
  376. mindspore/ops_generate/op_def/__init__.py +0 -0
  377. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  378. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  379. mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
  380. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  381. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  382. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  383. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  384. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  385. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  386. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  387. mindspore/ops_generate/pyboost/__init__.py +0 -0
  388. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  389. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  390. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  391. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  392. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  393. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  394. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  395. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  396. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  397. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  398. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  399. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  400. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  401. mindspore/ops_generate/resources/__init__.py +0 -0
  402. mindspore/ops_generate/resources/resource_list.py +30 -0
  403. mindspore/ops_generate/resources/resource_loader.py +36 -0
  404. mindspore/ops_generate/resources/resource_manager.py +64 -0
  405. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  406. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  407. mindspore/parallel/__init__.py +7 -3
  408. mindspore/parallel/_auto_parallel_context.py +159 -40
  409. mindspore/parallel/_cell_wrapper.py +132 -15
  410. mindspore/parallel/_parallel_serialization.py +107 -5
  411. mindspore/parallel/_ps_context.py +1 -1
  412. mindspore/parallel/_recovery_context.py +7 -2
  413. mindspore/parallel/_tensor.py +142 -18
  414. mindspore/parallel/_utils.py +199 -23
  415. mindspore/parallel/algo_parameter_config.py +4 -4
  416. mindspore/parallel/auto_parallel.py +732 -0
  417. mindspore/parallel/checkpoint_convert.py +159 -0
  418. mindspore/parallel/checkpoint_transform.py +700 -35
  419. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  420. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  421. mindspore/parallel/cluster/run.py +21 -4
  422. mindspore/parallel/function/__init__.py +24 -0
  423. mindspore/parallel/function/reshard_func.py +258 -0
  424. mindspore/parallel/nn/__init__.py +25 -0
  425. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  426. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  427. mindspore/parallel/parameter_broadcast.py +25 -14
  428. mindspore/parallel/shard.py +137 -59
  429. mindspore/parallel/transform_safetensors.py +364 -305
  430. mindspore/profiler/__init__.py +22 -5
  431. mindspore/profiler/analysis/__init__.py +0 -0
  432. mindspore/profiler/analysis/parser/__init__.py +0 -0
  433. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  434. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  435. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  436. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  437. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  438. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  439. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  440. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  441. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
  442. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  443. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  444. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  445. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  446. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  447. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  448. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  449. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  450. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  451. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  452. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  453. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  454. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  455. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  456. mindspore/profiler/analysis/task_manager.py +131 -0
  457. mindspore/profiler/analysis/time_converter.py +84 -0
  458. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  459. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  460. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  461. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  462. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  463. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  464. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  465. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  466. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  467. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  468. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  469. mindspore/profiler/analysis/work_flow.py +73 -0
  470. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  471. mindspore/profiler/common/command_executor.py +90 -0
  472. mindspore/profiler/common/constant.py +186 -3
  473. mindspore/profiler/common/file_manager.py +208 -0
  474. mindspore/profiler/common/log.py +130 -0
  475. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  476. mindspore/profiler/common/path_manager.py +395 -0
  477. mindspore/profiler/common/process_bar.py +168 -0
  478. mindspore/profiler/common/process_pool.py +9 -3
  479. mindspore/profiler/common/profiler_context.py +500 -0
  480. mindspore/profiler/common/profiler_info.py +304 -0
  481. mindspore/profiler/common/profiler_meta_data.py +74 -0
  482. mindspore/profiler/common/profiler_output_path.py +284 -0
  483. mindspore/profiler/common/profiler_parameters.py +251 -0
  484. mindspore/profiler/common/profiler_path_manager.py +179 -0
  485. mindspore/profiler/common/record_function.py +76 -0
  486. mindspore/profiler/common/tlv_decoder.py +76 -0
  487. mindspore/profiler/common/util.py +75 -2
  488. mindspore/profiler/dynamic_profiler.py +341 -75
  489. mindspore/profiler/envprofiler.py +163 -0
  490. mindspore/profiler/experimental_config.py +197 -0
  491. mindspore/profiler/mstx.py +242 -0
  492. mindspore/profiler/platform/__init__.py +21 -0
  493. mindspore/profiler/platform/base_profiler.py +40 -0
  494. mindspore/profiler/platform/cpu_profiler.py +124 -0
  495. mindspore/profiler/platform/gpu_profiler.py +74 -0
  496. mindspore/profiler/platform/npu_profiler.py +335 -0
  497. mindspore/profiler/profiler.py +1073 -90
  498. mindspore/profiler/profiler_action_controller.py +187 -0
  499. mindspore/profiler/profiler_interface.py +118 -0
  500. mindspore/profiler/schedule.py +243 -0
  501. mindspore/rewrite/api/node.py +15 -13
  502. mindspore/rewrite/api/symbol_tree.py +2 -3
  503. mindspore/run_check/_check_version.py +27 -20
  504. mindspore/run_check/run_check.py +1 -1
  505. mindspore/runtime/__init__.py +37 -0
  506. mindspore/runtime/device.py +27 -0
  507. mindspore/runtime/event.py +209 -0
  508. mindspore/runtime/executor.py +177 -0
  509. mindspore/runtime/memory.py +416 -0
  510. mindspore/runtime/stream.py +460 -0
  511. mindspore/runtime/thread_bind_core.py +401 -0
  512. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  513. mindspore/swresample-4.dll +0 -0
  514. mindspore/swscale-6.dll +0 -0
  515. mindspore/tinyxml2.dll +0 -0
  516. mindspore/train/__init__.py +8 -8
  517. mindspore/train/_utils.py +96 -27
  518. mindspore/train/amp.py +9 -5
  519. mindspore/train/callback/__init__.py +2 -2
  520. mindspore/train/callback/_callback.py +2 -16
  521. mindspore/train/callback/_checkpoint.py +53 -55
  522. mindspore/train/callback/_cluster_monitor.py +14 -18
  523. mindspore/train/callback/_early_stop.py +1 -1
  524. mindspore/train/callback/_flops_collector.py +103 -68
  525. mindspore/train/callback/_history.py +8 -5
  526. mindspore/train/callback/_lambda_callback.py +2 -2
  527. mindspore/train/callback/_landscape.py +0 -3
  528. mindspore/train/callback/_loss_monitor.py +2 -1
  529. mindspore/train/callback/_on_request_exit.py +6 -5
  530. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  531. mindspore/train/callback/_summary_collector.py +52 -19
  532. mindspore/train/callback/_time_monitor.py +2 -1
  533. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
  534. mindspore/train/data_sink.py +25 -2
  535. mindspore/train/dataset_helper.py +15 -16
  536. mindspore/train/loss_scale_manager.py +8 -7
  537. mindspore/train/metrics/accuracy.py +3 -3
  538. mindspore/train/metrics/confusion_matrix.py +9 -9
  539. mindspore/train/metrics/error.py +3 -3
  540. mindspore/train/metrics/hausdorff_distance.py +4 -4
  541. mindspore/train/metrics/mean_surface_distance.py +3 -3
  542. mindspore/train/metrics/metric.py +0 -12
  543. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  544. mindspore/train/metrics/precision.py +11 -10
  545. mindspore/train/metrics/recall.py +9 -9
  546. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  547. mindspore/train/mind_ir_pb2.py +174 -46
  548. mindspore/train/model.py +269 -136
  549. mindspore/train/serialization.py +622 -978
  550. mindspore/train/summary/_summary_adapter.py +2 -2
  551. mindspore/train/summary/summary_record.py +2 -3
  552. mindspore/train/train_thor/model_thor.py +1 -1
  553. mindspore/turbojpeg.dll +0 -0
  554. mindspore/utils/__init__.py +6 -3
  555. mindspore/utils/dryrun.py +140 -0
  556. mindspore/utils/hooks.py +81 -0
  557. mindspore/utils/runtime_execution_order_check.py +552 -0
  558. mindspore/utils/utils.py +138 -4
  559. mindspore/version.py +1 -1
  560. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
  561. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +564 -395
  562. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
  563. mindspore/_install_custom.py +0 -43
  564. mindspore/common/_register_for_adapter.py +0 -74
  565. mindspore/common/_tensor_overload.py +0 -139
  566. mindspore/mindspore_np_dtype.dll +0 -0
  567. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  568. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  569. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  570. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  571. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  572. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  573. mindspore/ops_generate/gen_utils.py +0 -209
  574. mindspore/ops_generate/op_proto.py +0 -145
  575. mindspore/ops_generate/template.py +0 -261
  576. mindspore/profiler/envprofiling.py +0 -254
  577. mindspore/profiler/profiling.py +0 -1926
  578. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  579. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,473 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """Op Proto module for defining operator prototypes and their arguments."""
17
+
18
+ import os
19
+ from typing import Dict
20
+
21
+ from resources.resource_loader import ResourceLoader
22
+ from resources.resource_list import ResourceType
23
+
24
+ from . import gen_constants as K
25
+ from .gen_utils import safe_load_yaml_from_dir
26
+
27
+
28
+ class OpArg:
29
+ """
30
+ Represents an argument of an operator.
31
+
32
+ Attributes:
33
+ arg_name (str): The name of the argument.
34
+ arg_dtype (str): The data type of the argument.
35
+ type_cast (list): A list of type casts applicable to the argument.
36
+ is_type_id (bool): Indicates if the argument is a type identifier.
37
+ as_init_arg (bool): Indicates if the argument is an initialization argument.
38
+ default: The default value of the argument.
39
+ inplace (str): The name of the inplace tensor if applicable.
40
+ is_prim_init (bool): Indicates if the argument is a primitive initialization argument.
41
+ arg_handler (str): A handler for the argument, if applicable.
42
+ """
43
+
44
+ def __init__(self, arg_name, arg_dtype, type_cast, is_type_id=False, as_init_arg=False, default=-1, inplace='',
45
+ is_prim_init=False, arg_handler=''):
46
+ self.arg_name = arg_name
47
+ self.arg_dtype = arg_dtype
48
+ self.type_cast = type_cast
49
+ self.is_type_id = is_type_id
50
+ self.as_init_arg = as_init_arg
51
+ self.default = default
52
+ self.inplace = inplace
53
+ self.is_prim_init = is_prim_init
54
+ self.arg_handler = arg_handler
55
+
56
+
57
+ class OpArgsSignature:
58
+ """
59
+ Represents the signature of operator arguments.
60
+
61
+ Attributes:
62
+ rw_write (list): Arguments that are written to.
63
+ rw_read (list): Arguments that are read from.
64
+ rw_ref (list): Arguments that are passed by reference.
65
+ dtype_group (list): Grouping of data types for the arguments.
66
+ """
67
+
68
+ def __init__(self, rw_write=None, rw_read=None, rw_ref=None, dtype_group=None):
69
+ self.rw_write = rw_write
70
+ self.rw_read = rw_read
71
+ self.rw_ref = rw_ref
72
+ self.dtype_group = dtype_group
73
+
74
+
75
+ class OpFunction:
76
+ """
77
+ Represents the function associated with an operator.
78
+
79
+ Attributes:
80
+ disable (bool): Indicates if the function is disabled.
81
+ name (str): The name of the function.
82
+ """
83
+
84
+ def __init__(self, disable=False, name=''):
85
+ self.disable = disable
86
+ self.name = name
87
+
88
+
89
+ class OpClass:
90
+ """
91
+ Represents a class associated with an operator.
92
+
93
+ Attributes:
94
+ disable (bool): Indicates if the class is disabled.
95
+ name (str): The name of the class.
96
+ """
97
+
98
+ def __init__(self, disable=False, name=''):
99
+ self.disable = disable
100
+ self.name = name
101
+
102
+
103
+ class OpDispatch:
104
+ """
105
+ Represents the dispatch information for an operator.
106
+
107
+ Attributes:
108
+ enable (bool): Indicates if the dispatch is enabled.
109
+ is_comm_op (bool): Indicates if the dispatch is communication operator or not.
110
+ ascend (str): The dispatch type for the Ascend device.
111
+ cpu (str): The dispatch type for the CPU.
112
+ gpu (str): The dispatch type for the GPU.
113
+ """
114
+
115
+ def __init__(self, enable=False, is_comm_op=False, ascend='default', cpu='default', gpu='default'):
116
+ self.enable = enable
117
+ self.is_comm_op = is_comm_op
118
+ self.ascend = ascend
119
+ self.cpu = cpu
120
+ self.gpu = gpu
121
+
122
+
123
+ class OpProto:
124
+ """
125
+ Defines a prototype for an operator in MindSpore.
126
+
127
+ This class is used to parse the operator definition from a YAML file and to generate
128
+ the necessary primitive and PyBoost functions.
129
+
130
+ Attributes:
131
+ op_name (str): The name of the operator.
132
+ op_args (list): A list of arguments for the operator.
133
+ op_function (OpFunction): The function associated with the operator.
134
+ op_class (OpClass): The class associated with the operator.
135
+ op_dispatch (OpDispatch): The dispatch information for the operator.
136
+ op_args_signature (OpArgsSignature): The signature of the operator's arguments.
137
+ op_returns (list): A list of return values for the operator.
138
+ op_view (bool): Indicates if the operator is a view operator.
139
+ """
140
+
141
+ def __init__(self,
142
+ op_name,
143
+ op_args,
144
+ op_function,
145
+ op_class,
146
+ op_dispatch,
147
+ op_args_signature,
148
+ op_returns,
149
+ op_view=False,
150
+ op_graph_view=False,
151
+ op_inplace=False,
152
+ op_labels=None,
153
+ op_deprecated=None,
154
+ bprop_expander=True):
155
+ self.op_name = op_name
156
+ self.op_args = op_args
157
+ self.op_function = op_function
158
+ self.op_class = op_class
159
+ self.op_dispatch = op_dispatch
160
+ self.op_args_signature = op_args_signature
161
+ self.op_returns = op_returns
162
+ self.op_view = op_view
163
+ self.op_graph_view = op_graph_view
164
+ self.op_inplace = op_inplace
165
+ self.op_labels = op_labels
166
+ self.op_deprecated = op_deprecated
167
+ self.bprop_expander = bprop_expander
168
+
169
+ @staticmethod
170
+ def load_from_yaml(op_name, op_data):
171
+ """
172
+ Loads an operator prototype from YAML data.
173
+
174
+ Args:
175
+ op_name (str): The name of the operation.
176
+ op_data (dict): A dictionary containing the operation data.
177
+
178
+ Returns:
179
+ OpProto: An instance of OpProto representing the operator.
180
+ """
181
+ # check op keys
182
+ check_validation(op_name, op_data)
183
+ # get op args
184
+ op_args = get_op_args(op_name, op_data)
185
+ # get op return args
186
+ op_returns = get_op_returns(op_name, op_data)
187
+ # get op args signature
188
+ op_args_signature = get_op_args_signature(op_name, op_data)
189
+ # get op class
190
+ op_class = get_op_class(op_name, op_data)
191
+ # get op function
192
+ op_function = get_op_function(op_name, op_data)
193
+ # get op dispatch
194
+ op_dispatch = get_op_dispatch(op_name, op_data)
195
+ # get op view
196
+ op_view = op_data.get('view', False)
197
+ if not isinstance(op_view, bool):
198
+ raise TypeError(
199
+ f'The view value should be bool, but get {type(op_view)}, op name is {op_name}.')
200
+ # get op graph view
201
+ op_graph_view = op_data.get('graph_view', False)
202
+ if not isinstance(op_graph_view, bool):
203
+ raise TypeError(
204
+ f'The graph view value should be bool, but get {type(op_graph_view)}, op name is {op_name}.')
205
+ op_inplace = is_inplace_op(op_returns)
206
+ # get op labels
207
+ op_labels = op_data.get('labels', None)
208
+ # get op deprecated
209
+ op_deprecated = op_data.get('deprecated', None)
210
+ bprop_expander = op_data.get('bprop_expander', True)
211
+ op_proto = OpProto(op_name=op_name, op_args=op_args, op_returns=op_returns, op_function=op_function,
212
+ op_class=op_class, op_dispatch=op_dispatch, op_args_signature=op_args_signature,
213
+ op_view=op_view, op_graph_view=op_graph_view, op_inplace=op_inplace, op_labels=op_labels,
214
+ op_deprecated=op_deprecated, bprop_expander=bprop_expander)
215
+ return op_proto
216
+
217
+
218
+ class OpProtoLoader(ResourceLoader):
219
+ """
220
+ OpProtoLoader is a class for loading operator prototypes from YAML data.
221
+ """
222
+ def __init__(self):
223
+ ops_yaml_path = os.path.join(K.WORK_DIR, K.MS_OP_DEF_YAML_PATH)
224
+ infer_ops_yaml_path = os.path.join(ops_yaml_path, 'infer')
225
+ self.yaml_paths = [ops_yaml_path, infer_ops_yaml_path]
226
+ self.type = ResourceType.OP_PROTO
227
+ self.is_deprecated = False
228
+ self.func_op = False
229
+
230
+ def load(self) -> Dict[ResourceType, object]:
231
+ """
232
+ Load OpProto.
233
+
234
+ Returns:
235
+ Dict[ResourceType, object]: The resource type and the OpProto.
236
+ """
237
+ yaml_dict = {}
238
+ for yaml_path in self.yaml_paths:
239
+ yaml_dict.update(safe_load_yaml_from_dir(yaml_path))
240
+ op_protos = []
241
+ for op_name, op_data in yaml_dict.items():
242
+ op_proto = OpProto.load_from_yaml(op_name, op_data)
243
+ if self.is_deprecated:
244
+ op_proto.op_name = 'deprecated_' + op_name
245
+ op_proto.func_op = self.func_op
246
+ op_protos.append(op_proto)
247
+ return {self.type: op_protos}
248
+
249
+
250
+ class DeprecatedOpProtoLoader(OpProtoLoader):
251
+ """
252
+ DeprecatedOpProtoLoader is a class for loading deprecated operator prototypes from YAML data.
253
+ """
254
+ def __init__(self):
255
+ super().__init__()
256
+ self.yaml_paths = [os.path.join(K.WORK_DIR, K.MS_OP_DEPRECATED_DEF_YAML_PATH)]
257
+ self.type = ResourceType.DEPRECATED_OP_PROTO
258
+ self.is_deprecated = True
259
+ self.func_op = True
260
+
261
+
262
+ class FuncOpProtoLoader(OpProtoLoader):
263
+ """
264
+ FuncOpProtoLoader is a class for loading func_op operator prototypes from YAML data.
265
+ """
266
+ def __init__(self):
267
+ super().__init__()
268
+ self.yaml_paths = [os.path.join(K.WORK_DIR, K.MS_OP_DEF_FUNC_OP_YAML_PATH)]
269
+ self.type = ResourceType.FUNC_OP_PROTO
270
+ self.is_deprecated = False
271
+ self.func_op = True
272
+
273
+
274
+ def get_op_args_signature(op_name, op_data):
275
+ """
276
+ Retrieves the argument signature from the operation data.
277
+
278
+ Args:
279
+ op_data (dict): A dictionary containing the operation data.
280
+
281
+ Returns:
282
+ OpArgsSignature: An instance of OpArgsSignature containing the argument signature.
283
+ """
284
+ op_args_signature = op_data.get('args_signature', None)
285
+ if op_args_signature is not None:
286
+ args_signature_keys = op_args_signature.keys()
287
+ check_op_yaml_keys(op_name, set(args_signature_keys), K.ARG_SIGNATURE_KEYS)
288
+ rw_write = op_args_signature.get('rw_write', None)
289
+ rw_read = op_args_signature.get('rw_read', None)
290
+ rw_ref = op_args_signature.get('rw_ref', None)
291
+ dtype_group = op_args_signature.get('dtype_group', None)
292
+ return OpArgsSignature(rw_write, rw_read, rw_ref, dtype_group)
293
+ return None
294
+
295
+
296
+ def check_validation(op_name: str, op_data: dict):
297
+ """
298
+ Validates the operator data to ensure it contains necessary keys.
299
+
300
+ Args:
301
+ op_data (dict): The operator data to validate.
302
+
303
+ Raises:
304
+ TypeError: If the required keys 'args' or 'returns' are missing.
305
+ """
306
+ # check keys
307
+ check_op_yaml_keys(op_name, set(op_data.keys()), K.OP_KEYS)
308
+
309
+ # Those keys must in yaml
310
+ if 'args' not in op_data.keys():
311
+ raise TypeError(f"Op define miss key 'args', op name is {op_name}")
312
+ if 'returns' not in op_data.keys():
313
+ raise TypeError(f"Op define miss key 'returns', op name is {op_name}")
314
+
315
+
316
+ def get_op_args(op_name, op_data):
317
+ """
318
+ Retrieves the arguments for the operator from the operation data.
319
+
320
+ Args:
321
+ op_data (dict): A dictionary containing the operation data.
322
+
323
+ Returns:
324
+ list: A list of OpArg instances representing the arguments of the operator.
325
+ """
326
+ args_dict = op_data.get('args')
327
+ op_args = []
328
+ for arg_name in args_dict.keys():
329
+ arg_keys = args_dict[arg_name].keys()
330
+ check_op_yaml_keys(op_name, set(arg_keys), K.ARG_KEYS)
331
+ arg_dtype = args_dict[arg_name]['dtype']
332
+ if arg_dtype == 'TypeId':
333
+ arg_dtype = 'int'
334
+ default = None
335
+ as_init_arg = False
336
+ is_type_id = False
337
+ prim_init = False
338
+ type_cast = []
339
+ if 'default' in args_dict[arg_name]:
340
+ default = args_dict[arg_name]['default']
341
+ as_init_arg = True
342
+ # 当op_args任意一个参数有prim_init,该op就要在pyboost_inner_prim.py生成
343
+ if 'prim_init' in args_dict[arg_name] and args_dict[arg_name]['prim_init'] is True:
344
+ prim_init = True
345
+ if 'type_cast' in args_dict[arg_name]:
346
+ type_cast = [cast_type.strip() for cast_type in args_dict[arg_name]['type_cast'].split(',')]
347
+ arg_handler_key = 'arg_handler'
348
+ arg_handler = args_dict[arg_name].get(arg_handler_key, '')
349
+ if arg_handler_key in args_dict[arg_name] and args_dict[arg_name][arg_handler_key] == 'dtype_to_type_id':
350
+ is_type_id = True
351
+ op_arg = OpArg(arg_name, arg_dtype, type_cast, is_type_id, as_init_arg, default,
352
+ is_prim_init=prim_init, arg_handler=arg_handler)
353
+ op_args.append(op_arg)
354
+ return op_args
355
+
356
+
357
+ def get_op_returns(op_name, op_data):
358
+ """
359
+ Retrieves the return values for the operator from the operation data.
360
+
361
+ Args:
362
+ op_data (dict): A dictionary containing the operation data.
363
+
364
+ Returns:
365
+ list: A list of OpArg instances representing the return values of the operator.
366
+ """
367
+ op_return_args = []
368
+ return_dict = op_data['returns']
369
+ for return_name in return_dict.keys():
370
+ return_keys = return_dict[return_name].keys()
371
+ check_op_yaml_keys(op_name, set(return_keys), K.RETURN_KEYS)
372
+ inplace = ''
373
+ if 'inplace' in return_dict[return_name]:
374
+ inplace = return_dict[return_name]['inplace']
375
+ if 'dtype' not in return_dict[return_name]:
376
+ raise TypeError("op return args need key 'dtype'")
377
+ dtype = return_dict[return_name]['dtype']
378
+ arg = OpArg(return_name, dtype, type_cast=[], inplace=inplace)
379
+ op_return_args.append(arg)
380
+ return op_return_args
381
+
382
+
383
+ def get_op_dispatch(op_name, op_data):
384
+ """
385
+ Retrieves the dispatch information for the operator from the operation data.
386
+
387
+ Args:
388
+ op_data (dict): A dictionary containing the operation data.
389
+
390
+ Returns:
391
+ OpDispatch: An instance of OpDispatch containing the dispatch information.
392
+ """
393
+ op_dispatch = op_data.get('dispatch', {})
394
+ dispatch_keys = op_dispatch.keys()
395
+ check_op_yaml_keys(op_name, set(dispatch_keys), K.DISPATCH_KEYS)
396
+ if not op_dispatch:
397
+ return None
398
+ enable = op_dispatch.get('enable', False)
399
+ if not isinstance(enable, bool):
400
+ raise TypeError(
401
+ f'The dispatch enable value should be bool, but get {type(enable)}, op name is {op_name}.')
402
+ is_comm_op = op_dispatch.get('is_comm_op', False)
403
+ ascend = op_dispatch.get('Ascend', 'default')
404
+ cpu = op_dispatch.get('CPU', 'default')
405
+ gpu = op_dispatch.get('GPU', 'default')
406
+ return OpDispatch(enable, is_comm_op, ascend, cpu, gpu)
407
+
408
+
409
+ def get_op_class(op_name, op_data) -> OpClass:
410
+ """
411
+ Retrieves the class information for the operator from the operation data.
412
+
413
+ Args:
414
+ op_name (str): The name of the operation.
415
+ op_data (dict): A dictionary containing the operation data.
416
+
417
+ Returns:
418
+ OpClass: An instance of OpClass containing the class information for the operator.
419
+ """
420
+ op_class = op_data.get('class', {})
421
+ class_keys = op_class.keys()
422
+ check_op_yaml_keys(op_name, set(class_keys), K.CLASS_KEYS)
423
+ is_disable = op_class.get('disable', False)
424
+ if not isinstance(is_disable, bool):
425
+ raise TypeError(
426
+ f'The class disable value should be bool, but get {type(is_disable)}, op name is {op_name}.')
427
+ class_name = op_class.get('name', convert_python_func_name_to_c(op_name))
428
+ return OpClass(disable=is_disable, name=class_name)
429
+
430
+
431
+ def get_op_function(op_name, op_data) -> OpFunction:
432
+ """
433
+ Retrieves the function information for the operator from the operation data.
434
+
435
+ Args:
436
+ op_name (str): default operation function name.
437
+ op_data (dict): A dictionary containing the operation data.
438
+
439
+ Returns:
440
+ OpFunction: An instance of OpFunction containing the function information for the operator.
441
+ """
442
+ op_function = op_data.get('function', {})
443
+ function_keys = op_function.keys()
444
+ check_op_yaml_keys(op_name, set(function_keys), K.FUNCTION_KEYS)
445
+ is_disable = op_function.get('disable', False)
446
+ if not isinstance(is_disable, bool):
447
+ raise TypeError(
448
+ f'The function disable value should be bool, but get {type(is_disable)}, op name is {op_name}.')
449
+ function_name = op_function.get('name', op_name)
450
+ return OpFunction(disable=is_disable, name=function_name)
451
+
452
+
453
+ def convert_python_func_name_to_c(func_name: str) -> str:
454
+ return ''.join(word.capitalize() for word in func_name.split('_'))
455
+
456
+
457
+ def check_op_yaml_keys(op_name: str, input_keys: set, compare_keys: set):
458
+ diff_keys = input_keys - compare_keys
459
+ if diff_keys:
460
+ raise TypeError(
461
+ f'The definition of keys in yaml has faults, op name is {op_name}, wrong keys are {diff_keys}.')
462
+
463
+
464
+ def is_inplace_op(args):
465
+ """
466
+ is inplace op
467
+ :param args:
468
+ :return: bool
469
+ """
470
+ for arg in args:
471
+ if arg.inplace:
472
+ return True
473
+ return False