mindspore 2.5.0__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (491)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +24 -193
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +97 -74
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +1915 -3287
  46. mindspore/common/api.py +341 -354
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/hook_handle.py +5 -3
  52. mindspore/common/initializer.py +10 -6
  53. mindspore/common/jit_begin_end.py +94 -0
  54. mindspore/common/jit_config.py +6 -1
  55. mindspore/common/jit_context.py +76 -0
  56. mindspore/common/jit_trace.py +378 -0
  57. mindspore/common/lazy_inline.py +2 -2
  58. mindspore/common/mutable.py +5 -4
  59. mindspore/common/parameter.py +106 -39
  60. mindspore/common/seed.py +2 -2
  61. mindspore/common/sparse_tensor.py +23 -17
  62. mindspore/common/tensor.py +297 -714
  63. mindspore/communication/__init__.py +7 -5
  64. mindspore/communication/_comm_helper.py +47 -2
  65. mindspore/communication/comm_func.py +70 -53
  66. mindspore/communication/management.py +83 -17
  67. mindspore/context.py +214 -560
  68. mindspore/dataset/__init__.py +44 -20
  69. mindspore/dataset/audio/__init__.py +2 -8
  70. mindspore/dataset/audio/transforms.py +3 -17
  71. mindspore/dataset/core/config.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +1 -1
  73. mindspore/dataset/engine/datasets.py +102 -120
  74. mindspore/dataset/engine/datasets_audio.py +22 -22
  75. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  76. mindspore/dataset/engine/datasets_text.py +78 -85
  77. mindspore/dataset/engine/datasets_user_defined.py +108 -76
  78. mindspore/dataset/engine/datasets_vision.py +111 -108
  79. mindspore/dataset/engine/iterators.py +5 -3
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/samplers.py +279 -57
  82. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  83. mindspore/dataset/engine/validators.py +10 -0
  84. mindspore/dataset/text/__init__.py +7 -6
  85. mindspore/dataset/text/transforms.py +6 -5
  86. mindspore/dataset/text/utils.py +3 -3
  87. mindspore/dataset/transforms/__init__.py +0 -9
  88. mindspore/dataset/transforms/transforms.py +3 -3
  89. mindspore/dataset/utils/browse_dataset.py +1 -1
  90. mindspore/dataset/vision/__init__.py +2 -9
  91. mindspore/dataset/vision/transforms.py +202 -158
  92. mindspore/dataset/vision/utils.py +7 -5
  93. mindspore/device_context/ascend/op_debug.py +60 -1
  94. mindspore/device_context/ascend/op_tuning.py +0 -4
  95. mindspore/device_manager.py +39 -3
  96. mindspore/dnnl.dll +0 -0
  97. mindspore/dpcmi.dll +0 -0
  98. mindspore/experimental/es/embedding_service.py +35 -27
  99. mindspore/experimental/map_parameter.py +4 -4
  100. mindspore/experimental/optim/adadelta.py +22 -26
  101. mindspore/experimental/optim/adagrad.py +4 -4
  102. mindspore/experimental/optim/adam.py +4 -0
  103. mindspore/experimental/optim/adamax.py +4 -4
  104. mindspore/experimental/optim/adamw.py +4 -0
  105. mindspore/experimental/optim/asgd.py +1 -1
  106. mindspore/experimental/optim/lr_scheduler.py +40 -22
  107. mindspore/experimental/optim/radam.py +5 -5
  108. mindspore/experimental/optim/rprop.py +1 -1
  109. mindspore/experimental/optim/sgd.py +1 -1
  110. mindspore/hal/contiguous_tensors_handle.py +6 -10
  111. mindspore/hal/device.py +55 -81
  112. mindspore/hal/event.py +38 -55
  113. mindspore/hal/memory.py +93 -144
  114. mindspore/hal/stream.py +81 -125
  115. mindspore/include/dataset/constants.h +7 -4
  116. mindspore/include/dataset/execute.h +2 -2
  117. mindspore/jpeg62.dll +0 -0
  118. mindspore/log.py +40 -2
  119. mindspore/mindrecord/__init__.py +20 -7
  120. mindspore/mindspore_backend_common.dll +0 -0
  121. mindspore/mindspore_backend_manager.dll +0 -0
  122. mindspore/mindspore_common.dll +0 -0
  123. mindspore/mindspore_core.dll +0 -0
  124. mindspore/mindspore_dump.dll +0 -0
  125. mindspore/mindspore_frontend.dll +0 -0
  126. mindspore/mindspore_glog.dll +0 -0
  127. mindspore/mindspore_memory_pool.dll +0 -0
  128. mindspore/mindspore_ms_backend.dll +0 -0
  129. mindspore/mindspore_ops.dll +0 -0
  130. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  131. mindspore/mindspore_ops_kernel_common.dll +0 -0
  132. mindspore/mindspore_profiler.dll +0 -0
  133. mindspore/mindspore_pyboost.dll +0 -0
  134. mindspore/mindspore_pynative.dll +0 -0
  135. mindspore/mindspore_res_manager.dll +0 -0
  136. mindspore/mindspore_runtime_pipeline.dll +0 -0
  137. mindspore/mint/__init__.py +131 -700
  138. mindspore/mint/distributed/__init__.py +5 -1
  139. mindspore/mint/distributed/distributed.py +194 -109
  140. mindspore/mint/linalg/__init__.py +2 -0
  141. mindspore/mint/nn/__init__.py +280 -18
  142. mindspore/mint/nn/functional.py +282 -64
  143. mindspore/mint/nn/layer/__init__.py +4 -0
  144. mindspore/mint/nn/layer/_functions.py +7 -3
  145. mindspore/mint/nn/layer/activation.py +120 -13
  146. mindspore/mint/nn/layer/conv.py +218 -24
  147. mindspore/mint/nn/layer/normalization.py +15 -16
  148. mindspore/mint/nn/layer/padding.py +1 -1
  149. mindspore/mint/nn/layer/pooling.py +66 -1
  150. mindspore/mint/optim/__init__.py +2 -1
  151. mindspore/mint/optim/sgd.py +171 -0
  152. mindspore/msobj140.dll +0 -0
  153. mindspore/mspdb140.dll +0 -0
  154. mindspore/mspdbcore.dll +0 -0
  155. mindspore/mspdbst.dll +0 -0
  156. mindspore/mspft140.dll +0 -0
  157. mindspore/msvcdis140.dll +0 -0
  158. mindspore/msvcp140_1.dll +0 -0
  159. mindspore/msvcp140_2.dll +0 -0
  160. mindspore/msvcp140_atomic_wait.dll +0 -0
  161. mindspore/msvcp140_codecvt_ids.dll +0 -0
  162. mindspore/nn/__init__.py +4 -1
  163. mindspore/nn/cell.py +1250 -176
  164. mindspore/nn/layer/activation.py +23 -21
  165. mindspore/nn/layer/basic.py +22 -16
  166. mindspore/nn/layer/container.py +1 -1
  167. mindspore/nn/layer/conv.py +22 -17
  168. mindspore/nn/layer/embedding.py +9 -8
  169. mindspore/nn/layer/normalization.py +48 -42
  170. mindspore/nn/layer/pooling.py +75 -31
  171. mindspore/nn/layer/transformer.py +11 -10
  172. mindspore/nn/learning_rate_schedule.py +4 -2
  173. mindspore/nn/loss/loss.py +27 -19
  174. mindspore/nn/optim/ada_grad.py +6 -5
  175. mindspore/nn/optim/adadelta.py +9 -7
  176. mindspore/nn/optim/adafactor.py +1 -1
  177. mindspore/nn/optim/adam.py +16 -12
  178. mindspore/nn/optim/adamax.py +8 -7
  179. mindspore/nn/optim/adasum.py +5 -5
  180. mindspore/nn/optim/asgd.py +1 -1
  181. mindspore/nn/optim/ftrl.py +11 -9
  182. mindspore/nn/optim/lamb.py +1 -1
  183. mindspore/nn/optim/lazyadam.py +12 -10
  184. mindspore/nn/optim/momentum.py +7 -6
  185. mindspore/nn/optim/optimizer.py +2 -2
  186. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  187. mindspore/nn/optim/rmsprop.py +13 -12
  188. mindspore/nn/optim/rprop.py +9 -7
  189. mindspore/nn/optim/sgd.py +9 -6
  190. mindspore/nn/optim/tft_wrapper.py +5 -2
  191. mindspore/nn/probability/bijector/bijector.py +17 -11
  192. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  193. mindspore/nn/probability/bijector/invert.py +2 -2
  194. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  195. mindspore/nn/probability/bijector/softplus.py +3 -2
  196. mindspore/nn/probability/distribution/beta.py +3 -3
  197. mindspore/nn/probability/distribution/categorical.py +1 -1
  198. mindspore/nn/probability/distribution/cauchy.py +4 -2
  199. mindspore/nn/probability/distribution/exponential.py +6 -7
  200. mindspore/nn/probability/distribution/gamma.py +2 -2
  201. mindspore/nn/probability/distribution/gumbel.py +2 -2
  202. mindspore/nn/probability/distribution/half_normal.py +5 -3
  203. mindspore/nn/probability/distribution/logistic.py +5 -3
  204. mindspore/nn/probability/distribution/poisson.py +1 -1
  205. mindspore/nn/probability/distribution/uniform.py +5 -3
  206. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  207. mindspore/nn/reinforcement/tensor_array.py +1 -1
  208. mindspore/nn/wrap/__init__.py +6 -6
  209. mindspore/nn/wrap/cell_wrapper.py +178 -117
  210. mindspore/nn/wrap/grad_reducer.py +45 -36
  211. mindspore/nn/wrap/loss_scale.py +3 -3
  212. mindspore/numpy/array_creations.py +3 -3
  213. mindspore/numpy/array_ops.py +1 -1
  214. mindspore/numpy/math_ops.py +4 -4
  215. mindspore/numpy/utils.py +1 -2
  216. mindspore/numpy/utils_const.py +1 -2
  217. mindspore/opencv_core452.dll +0 -0
  218. mindspore/opencv_imgcodecs452.dll +0 -0
  219. mindspore/opencv_imgproc452.dll +0 -0
  220. mindspore/ops/__init__.py +3 -2
  221. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  222. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  223. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  224. mindspore/ops/_register_for_op.py +0 -11
  225. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  226. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  227. mindspore/ops/_vmap/vmap_array_ops.py +7 -6
  228. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  229. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  230. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  231. mindspore/ops/auto_generate/__init__.py +4 -3
  232. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
  233. mindspore/ops/auto_generate/gen_extend_func.py +281 -135
  234. mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
  235. mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
  236. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  237. mindspore/ops/composite/__init__.py +2 -1
  238. mindspore/ops/composite/base.py +19 -24
  239. mindspore/ops/composite/math_ops.py +6 -16
  240. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  241. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
  242. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  243. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  244. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  248. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  249. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  250. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  251. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  252. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  254. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  255. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  256. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  257. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  259. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  260. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  263. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  264. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  267. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  268. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  271. mindspore/ops/function/__init__.py +28 -2
  272. mindspore/ops/function/_add_attr_func.py +58 -0
  273. mindspore/ops/function/array_func.py +1629 -2345
  274. mindspore/ops/function/clip_func.py +38 -45
  275. mindspore/ops/function/debug_func.py +36 -44
  276. mindspore/ops/function/grad/__init__.py +1 -0
  277. mindspore/ops/function/grad/grad_func.py +104 -71
  278. mindspore/ops/function/image_func.py +1 -1
  279. mindspore/ops/function/linalg_func.py +46 -78
  280. mindspore/ops/function/math_func.py +3035 -3705
  281. mindspore/ops/function/nn_func.py +676 -241
  282. mindspore/ops/function/other_func.py +159 -1
  283. mindspore/ops/function/parameter_func.py +17 -30
  284. mindspore/ops/function/random_func.py +204 -361
  285. mindspore/ops/function/reshard_func.py +4 -70
  286. mindspore/ops/function/sparse_func.py +3 -3
  287. mindspore/ops/function/sparse_unary_func.py +5 -5
  288. mindspore/ops/function/spectral_func.py +25 -58
  289. mindspore/ops/function/vmap_func.py +24 -17
  290. mindspore/ops/functional.py +6 -4
  291. mindspore/ops/functional_overload.py +547 -4
  292. mindspore/ops/op_info_register.py +32 -244
  293. mindspore/ops/operations/__init__.py +10 -5
  294. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  295. mindspore/ops/operations/_grad_ops.py +1 -10
  296. mindspore/ops/operations/_inner_ops.py +5 -76
  297. mindspore/ops/operations/_ms_kernel.py +4 -10
  298. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  299. mindspore/ops/operations/_scalar_ops.py +3 -2
  300. mindspore/ops/operations/_sequence_ops.py +1 -1
  301. mindspore/ops/operations/_tensor_array.py +1 -1
  302. mindspore/ops/operations/array_ops.py +37 -22
  303. mindspore/ops/operations/comm_ops.py +150 -107
  304. mindspore/ops/operations/custom_ops.py +221 -23
  305. mindspore/ops/operations/debug_ops.py +115 -16
  306. mindspore/ops/operations/inner_ops.py +1 -1
  307. mindspore/ops/operations/linalg_ops.py +1 -58
  308. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  309. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  310. mindspore/ops/operations/math_ops.py +21 -18
  311. mindspore/ops/operations/nn_ops.py +65 -191
  312. mindspore/ops/operations/other_ops.py +62 -9
  313. mindspore/ops/operations/random_ops.py +13 -7
  314. mindspore/ops/operations/reshard_ops.py +1 -1
  315. mindspore/ops/operations/sparse_ops.py +2 -2
  316. mindspore/ops/primitive.py +43 -32
  317. mindspore/ops/tensor_method.py +232 -13
  318. mindspore/ops_generate/__init__.py +0 -5
  319. mindspore/ops_generate/aclnn/__init__.py +0 -0
  320. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  321. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  322. mindspore/ops_generate/api/__init__.py +0 -0
  323. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  324. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  325. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  326. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  327. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  328. mindspore/ops_generate/api/gen_api.py +103 -0
  329. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  330. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  331. mindspore/ops_generate/common/__init__.py +0 -0
  332. mindspore/ops_generate/common/gen_constants.py +91 -0
  333. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  334. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  335. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  336. mindspore/ops_generate/gen_ops.py +23 -325
  337. mindspore/ops_generate/op_def/__init__.py +0 -0
  338. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  339. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  340. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
  341. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  342. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  343. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  344. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  345. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  346. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  347. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  348. mindspore/ops_generate/pyboost/__init__.py +0 -0
  349. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  350. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  351. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  352. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  353. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  354. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  355. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  356. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  357. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  358. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  359. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  360. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  361. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  362. mindspore/ops_generate/resources/__init__.py +0 -0
  363. mindspore/ops_generate/resources/resource_list.py +30 -0
  364. mindspore/ops_generate/resources/resource_loader.py +36 -0
  365. mindspore/ops_generate/resources/resource_manager.py +64 -0
  366. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  367. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  368. mindspore/parallel/__init__.py +6 -2
  369. mindspore/parallel/_auto_parallel_context.py +133 -6
  370. mindspore/parallel/_cell_wrapper.py +130 -15
  371. mindspore/parallel/_parallel_serialization.py +95 -4
  372. mindspore/parallel/_ps_context.py +1 -1
  373. mindspore/parallel/_recovery_context.py +7 -2
  374. mindspore/parallel/_tensor.py +142 -18
  375. mindspore/parallel/_utils.py +198 -25
  376. mindspore/parallel/algo_parameter_config.py +3 -3
  377. mindspore/parallel/auto_parallel.py +732 -0
  378. mindspore/parallel/checkpoint_convert.py +159 -0
  379. mindspore/parallel/checkpoint_transform.py +656 -37
  380. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  381. mindspore/parallel/cluster/run.py +1 -1
  382. mindspore/parallel/function/__init__.py +24 -0
  383. mindspore/parallel/function/reshard_func.py +259 -0
  384. mindspore/parallel/nn/__init__.py +25 -0
  385. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  386. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  387. mindspore/parallel/parameter_broadcast.py +24 -13
  388. mindspore/parallel/shard.py +137 -61
  389. mindspore/parallel/transform_safetensors.py +287 -95
  390. mindspore/pgodb140.dll +0 -0
  391. mindspore/pgort140.dll +0 -0
  392. mindspore/profiler/__init__.py +9 -5
  393. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  394. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  395. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
  397. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  398. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  399. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  400. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  401. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  402. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  403. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  404. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  405. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  406. mindspore/profiler/common/constant.py +12 -0
  407. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  408. mindspore/profiler/common/path_manager.py +24 -0
  409. mindspore/profiler/common/profiler_context.py +26 -2
  410. mindspore/profiler/common/profiler_meta_data.py +74 -0
  411. mindspore/profiler/common/profiler_parameters.py +59 -18
  412. mindspore/profiler/common/profiler_path_manager.py +66 -7
  413. mindspore/profiler/dynamic_profiler.py +112 -79
  414. mindspore/profiler/envprofiler.py +26 -1
  415. mindspore/profiler/experimental_config.py +197 -0
  416. mindspore/profiler/mstx.py +57 -14
  417. mindspore/profiler/platform/npu_profiler.py +33 -7
  418. mindspore/profiler/profiler.py +541 -45
  419. mindspore/profiler/profiler_action_controller.py +1 -1
  420. mindspore/profiler/profiler_interface.py +4 -0
  421. mindspore/profiler/schedule.py +57 -22
  422. mindspore/rewrite/api/node.py +15 -13
  423. mindspore/rewrite/api/symbol_tree.py +1 -1
  424. mindspore/run_check/_check_version.py +25 -14
  425. mindspore/run_check/run_check.py +1 -1
  426. mindspore/runtime/__init__.py +2 -2
  427. mindspore/runtime/executor.py +40 -11
  428. mindspore/runtime/memory.py +25 -8
  429. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  430. mindspore/swresample-4.dll +0 -0
  431. mindspore/swscale-6.dll +0 -0
  432. mindspore/tbbmalloc.dll +0 -0
  433. mindspore/tinyxml2.dll +0 -0
  434. mindspore/train/__init__.py +8 -8
  435. mindspore/train/_utils.py +35 -7
  436. mindspore/train/amp.py +1 -1
  437. mindspore/train/callback/__init__.py +2 -2
  438. mindspore/train/callback/_callback.py +2 -16
  439. mindspore/train/callback/_checkpoint.py +24 -40
  440. mindspore/train/callback/_cluster_monitor.py +14 -18
  441. mindspore/train/callback/_flops_collector.py +2 -3
  442. mindspore/train/callback/_history.py +7 -4
  443. mindspore/train/callback/_lambda_callback.py +2 -2
  444. mindspore/train/callback/_landscape.py +0 -3
  445. mindspore/train/callback/_loss_monitor.py +2 -1
  446. mindspore/train/callback/_on_request_exit.py +6 -5
  447. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  448. mindspore/train/callback/_summary_collector.py +8 -13
  449. mindspore/train/callback/_time_monitor.py +2 -1
  450. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
  451. mindspore/train/data_sink.py +25 -2
  452. mindspore/train/dataset_helper.py +4 -5
  453. mindspore/train/loss_scale_manager.py +8 -7
  454. mindspore/train/metrics/accuracy.py +3 -3
  455. mindspore/train/metrics/confusion_matrix.py +9 -9
  456. mindspore/train/metrics/error.py +3 -3
  457. mindspore/train/metrics/hausdorff_distance.py +4 -4
  458. mindspore/train/metrics/mean_surface_distance.py +3 -3
  459. mindspore/train/metrics/metric.py +0 -12
  460. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  461. mindspore/train/metrics/precision.py +8 -6
  462. mindspore/train/metrics/recall.py +9 -9
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  464. mindspore/train/mind_ir_pb2.py +19 -12
  465. mindspore/train/model.py +176 -103
  466. mindspore/train/serialization.py +246 -988
  467. mindspore/train/summary/_summary_adapter.py +2 -2
  468. mindspore/train/summary/summary_record.py +1 -1
  469. mindspore/turbojpeg.dll +0 -0
  470. mindspore/utils/__init__.py +3 -2
  471. mindspore/utils/dryrun.py +4 -2
  472. mindspore/utils/hooks.py +81 -0
  473. mindspore/utils/utils.py +138 -4
  474. mindspore/vcmeta.dll +0 -0
  475. mindspore/vcruntime140.dll +0 -0
  476. mindspore/vcruntime140_1.dll +0 -0
  477. mindspore/version.py +1 -1
  478. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
  479. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
  480. mindspore/_install_custom.py +0 -43
  481. mindspore/common/_register_for_adapter.py +0 -74
  482. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  483. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  484. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  485. mindspore/ops_generate/gen_constants.py +0 -190
  486. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  487. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  488. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  489. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  490. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -14,7 +14,8 @@
14
14
  # ============================================================================
15
15
 
16
16
  from mindspore.common._stub_tensor import _convert_stub
17
- from mindspore.ops.auto_generate.gen_arg_handler import *
17
+ from mindspore.ops._utils.arg_handler import *
18
+ from mindspore._c_expression import AdaptiveMaxPool2DPrim_
18
19
  from mindspore._c_expression import ArgMaxWithValuePrim_
19
20
  from mindspore._c_expression import ArgMinWithValuePrim_
20
21
  from mindspore._c_expression import BatchMatMulPrim_
@@ -66,23 +67,34 @@ from mindspore._c_expression import SmoothL1LossPrim_
66
67
  from mindspore._c_expression import SoftmaxPrim_
67
68
  from mindspore._c_expression import SoftShrinkGradPrim_
68
69
  from mindspore._c_expression import SoftShrinkPrim_
70
+ from mindspore._c_expression import SoftMarginLossGradPrim_
71
+ from mindspore._c_expression import SoftMarginLossPrim_
69
72
  from mindspore._c_expression import SplitPrim_
70
73
  from mindspore._c_expression import SqueezePrim_
71
74
  from mindspore._c_expression import StackExtPrim_
72
- from mindspore._c_expression import TrilExtPrim_
73
75
  from mindspore._c_expression import TriuPrim_
74
76
  from mindspore._c_expression import UniqueConsecutivePrim_
75
77
  from mindspore._c_expression import UpsampleTrilinear3DGradPrim_
76
78
  from mindspore._c_expression import UpsampleTrilinear3DPrim_
79
+ from mindspore._c_expression import FusedInferAttentionScorePrim_
77
80
  from mindspore._c_expression import GroupedMatmulPrim_
78
81
  from mindspore._c_expression import QuantBatchMatmulPrim_
79
82
  from mindspore._c_expression import WeightQuantBatchMatmulPrim_
80
83
 
81
84
 
85
+ class _PyboostAdaptiveMaxPool2DPrim(AdaptiveMaxPool2DPrim_):
86
+ def __call__(self, input, output_size):
87
+
88
+ return super().__call__([input, output_size])
89
+
90
+
91
+ adaptive_max_pool2d_impl = _PyboostAdaptiveMaxPool2DPrim()
92
+
93
+
82
94
  class _PyboostArgMaxWithValuePrim(ArgMaxWithValuePrim_):
83
95
  def __call__(self, input, axis, keep_dims):
84
96
 
85
- return _convert_stub(super().__call__([input, axis, keep_dims]))
97
+ return super().__call__([input, axis, keep_dims])
86
98
 
87
99
 
88
100
  argmax_with_value_impl = _PyboostArgMaxWithValuePrim()
@@ -91,7 +103,7 @@ argmax_with_value_impl = _PyboostArgMaxWithValuePrim()
91
103
  class _PyboostArgMinWithValuePrim(ArgMinWithValuePrim_):
92
104
  def __call__(self, input, axis, keep_dims):
93
105
 
94
- return _convert_stub(super().__call__([input, axis, keep_dims]))
106
+ return super().__call__([input, axis, keep_dims])
95
107
 
96
108
 
97
109
  argmin_with_value_impl = _PyboostArgMinWithValuePrim()
@@ -100,7 +112,7 @@ argmin_with_value_impl = _PyboostArgMinWithValuePrim()
100
112
  class _PyboostBatchMatMulPrim(BatchMatMulPrim_):
101
113
  def __call__(self, x, y, transpose_a, transpose_b):
102
114
 
103
- return _convert_stub(super().__call__([x, y, transpose_a, transpose_b]))
115
+ return super().__call__([x, y, transpose_a, transpose_b])
104
116
 
105
117
 
106
118
  batch_mat_mul_impl = _PyboostBatchMatMulPrim()
@@ -109,7 +121,7 @@ batch_mat_mul_impl = _PyboostBatchMatMulPrim()
109
121
  class _PyboostBatchNormGradExtPrim(BatchNormGradExtPrim_):
110
122
  def __call__(self, dout, input, weight, running_mean, running_var, saved_mean, saved_rstd, training, eps, output_mask):
111
123
 
112
- return _convert_stub(super().__call__([dout, input, weight, running_mean, running_var, saved_mean, saved_rstd, training, eps, output_mask]))
124
+ return super().__call__([dout, input, weight, running_mean, running_var, saved_mean, saved_rstd, training, eps, output_mask])
113
125
 
114
126
 
115
127
  batch_norm_grad_ext_impl = _PyboostBatchNormGradExtPrim()
@@ -118,7 +130,7 @@ batch_norm_grad_ext_impl = _PyboostBatchNormGradExtPrim()
118
130
  class _PyboostBinaryCrossEntropyGradPrim(BinaryCrossEntropyGradPrim_):
119
131
  def __call__(self, input, target, grad_output, weight, reduction):
120
132
  converted_reduction = str_to_enum('binary_cross_entropy_grad', 'reduction', reduction)
121
- return _convert_stub(super().__call__([input, target, grad_output, weight, converted_reduction]))
133
+ return super().__call__([input, target, grad_output, weight, converted_reduction])
122
134
 
123
135
 
124
136
  binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()
@@ -127,7 +139,7 @@ binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()
127
139
  class _PyboostBinaryCrossEntropyPrim(BinaryCrossEntropyPrim_):
128
140
  def __call__(self, input, target, weight, reduction):
129
141
  converted_reduction = str_to_enum('binary_cross_entropy', 'reduction', reduction)
130
- return _convert_stub(super().__call__([input, target, weight, converted_reduction]))
142
+ return super().__call__([input, target, weight, converted_reduction])
131
143
 
132
144
 
133
145
  binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()
@@ -136,7 +148,7 @@ binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()
136
148
  class _PyboostBCEWithLogitsLossPrim(BCEWithLogitsLossPrim_):
137
149
  def __call__(self, input, target, weight, posWeight, reduction):
138
150
  converted_reduction = str_to_enum('binary_cross_entropy_with_logits', 'reduction', reduction)
139
- return _convert_stub(super().__call__([input, target, weight, posWeight, converted_reduction]))
151
+ return super().__call__([input, target, weight, posWeight, converted_reduction])
140
152
 
141
153
 
142
154
  binary_cross_entropy_with_logits_impl = _PyboostBCEWithLogitsLossPrim()
@@ -145,7 +157,7 @@ binary_cross_entropy_with_logits_impl = _PyboostBCEWithLogitsLossPrim()
145
157
  class _PyboostBroadcastToPrim(BroadcastToPrim_):
146
158
  def __call__(self, input, shape):
147
159
 
148
- return _convert_stub(super().__call__([input, shape]))
160
+ return super().__call__([input, shape])
149
161
 
150
162
 
151
163
  broadcast_to_impl = _PyboostBroadcastToPrim()
@@ -154,7 +166,7 @@ broadcast_to_impl = _PyboostBroadcastToPrim()
154
166
  class _PyboostConcatPrim(ConcatPrim_):
155
167
  def __call__(self, tensors, axis):
156
168
 
157
- return _convert_stub(super().__call__([tensors, axis]))
169
+ return super().__call__([tensors, axis])
158
170
 
159
171
 
160
172
  concat_impl = _PyboostConcatPrim()
@@ -163,7 +175,7 @@ concat_impl = _PyboostConcatPrim()
163
175
  class _PyboostCrossPrim(CrossPrim_):
164
176
  def __call__(self, input, other, dim):
165
177
 
166
- return _convert_stub(super().__call__([input, other, dim]))
178
+ return super().__call__([input, other, dim])
167
179
 
168
180
 
169
181
  cross_impl = _PyboostCrossPrim()
@@ -172,7 +184,7 @@ cross_impl = _PyboostCrossPrim()
172
184
  class _PyboostCummaxPrim(CummaxPrim_):
173
185
  def __call__(self, input, axis):
174
186
 
175
- return _convert_stub(super().__call__([input, axis]))
187
+ return super().__call__([input, axis])
176
188
 
177
189
 
178
190
  cummax_impl = _PyboostCummaxPrim()
@@ -181,7 +193,7 @@ cummax_impl = _PyboostCummaxPrim()
181
193
  class _PyboostEluExtPrim(EluExtPrim_):
182
194
  def __call__(self, input, alpha):
183
195
 
184
- return _convert_stub(super().__call__([input, alpha]))
196
+ return super().__call__([input, alpha])
185
197
 
186
198
 
187
199
  elu_ext_impl = _PyboostEluExtPrim()
@@ -190,7 +202,7 @@ elu_ext_impl = _PyboostEluExtPrim()
190
202
  class _PyboostFFNExtPrim(FFNExtPrim_):
191
203
  def __call__(self, x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, activation, inner_precise):
192
204
  converted_activation = str_to_enum('ffn_ext', 'activation', activation)
193
- return _convert_stub(super().__call__([x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, converted_activation, inner_precise]))
205
+ return super().__call__([x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, converted_activation, inner_precise])
194
206
 
195
207
 
196
208
  ffn_ext_impl = _PyboostFFNExtPrim()
@@ -199,7 +211,7 @@ ffn_ext_impl = _PyboostFFNExtPrim()
199
211
  class _PyboostFlashAttentionScoreGradPrim(FlashAttentionScoreGradPrim_):
200
212
  def __call__(self, query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode):
201
213
  converted_input_layout = str_to_enum('flash_attention_score_grad', 'input_layout', input_layout)
202
- return _convert_stub(super().__call__([query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode]))
214
+ return super().__call__([query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode])
203
215
 
204
216
 
205
217
  flash_attention_score_grad_impl = _PyboostFlashAttentionScoreGradPrim()
@@ -208,7 +220,7 @@ flash_attention_score_grad_impl = _PyboostFlashAttentionScoreGradPrim()
208
220
  class _PyboostFlashAttentionScorePrim(FlashAttentionScorePrim_):
209
221
  def __call__(self, query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode):
210
222
  converted_input_layout = str_to_enum('flash_attention_score', 'input_layout', input_layout)
211
- return _convert_stub(super().__call__([query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode]))
223
+ return super().__call__([query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode])
212
224
 
213
225
 
214
226
  flash_attention_score_impl = _PyboostFlashAttentionScorePrim()
@@ -217,7 +229,7 @@ flash_attention_score_impl = _PyboostFlashAttentionScorePrim()
217
229
  class _PyboostGluGradPrim(GluGradPrim_):
218
230
  def __call__(self, grads, x, axis):
219
231
 
220
- return _convert_stub(super().__call__([grads, x, axis]))
232
+ return super().__call__([grads, x, axis])
221
233
 
222
234
 
223
235
  glu_grad_impl = _PyboostGluGradPrim()
@@ -226,7 +238,7 @@ glu_grad_impl = _PyboostGluGradPrim()
226
238
  class _PyboostGLUPrim(GLUPrim_):
227
239
  def __call__(self, x, axis):
228
240
 
229
- return _convert_stub(super().__call__([x, axis]))
241
+ return super().__call__([x, axis])
230
242
 
231
243
 
232
244
  glu_impl = _PyboostGLUPrim()
@@ -236,7 +248,7 @@ class _PyboostGridSampler2DGradPrim(GridSampler2DGradPrim_):
236
248
  def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners, output_mask):
237
249
  converted_interpolation_mode = str_to_enum('grid_sampler_2d_grad', 'interpolation_mode', interpolation_mode)
238
250
  converted_padding_mode = str_to_enum('grid_sampler_2d_grad', 'padding_mode', padding_mode)
239
- return _convert_stub(super().__call__([grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners, output_mask]))
251
+ return super().__call__([grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners, output_mask])
240
252
 
241
253
 
242
254
  grid_sampler_2d_grad_impl = _PyboostGridSampler2DGradPrim()
@@ -246,7 +258,7 @@ class _PyboostGridSampler2DPrim(GridSampler2DPrim_):
246
258
  def __call__(self, input_x, grid, interpolation_mode, padding_mode, align_corners):
247
259
  converted_interpolation_mode = str_to_enum('grid_sampler_2d', 'interpolation_mode', interpolation_mode)
248
260
  converted_padding_mode = str_to_enum('grid_sampler_2d', 'padding_mode', padding_mode)
249
- return _convert_stub(super().__call__([input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners]))
261
+ return super().__call__([input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners])
250
262
 
251
263
 
252
264
  grid_sampler_2d_impl = _PyboostGridSampler2DPrim()
@@ -256,7 +268,7 @@ class _PyboostGridSampler3DGradPrim(GridSampler3DGradPrim_):
256
268
  def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners, output_mask):
257
269
  converted_interpolation_mode = str_to_enum('grid_sampler_3d_grad', 'interpolation_mode', interpolation_mode)
258
270
  converted_padding_mode = str_to_enum('grid_sampler_3d_grad', 'padding_mode', padding_mode)
259
- return _convert_stub(super().__call__([grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners, output_mask]))
271
+ return super().__call__([grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners, output_mask])
260
272
 
261
273
 
262
274
  grid_sampler_3d_grad_impl = _PyboostGridSampler3DGradPrim()
@@ -266,7 +278,7 @@ class _PyboostGridSampler3DPrim(GridSampler3DPrim_):
266
278
  def __call__(self, input_x, grid, interpolation_mode, padding_mode, align_corners):
267
279
  converted_interpolation_mode = str_to_enum('grid_sampler_3d', 'interpolation_mode', interpolation_mode)
268
280
  converted_padding_mode = str_to_enum('grid_sampler_3d', 'padding_mode', padding_mode)
269
- return _convert_stub(super().__call__([input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners]))
281
+ return super().__call__([input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners])
270
282
 
271
283
 
272
284
  grid_sampler_3d_impl = _PyboostGridSampler3DPrim()
@@ -275,7 +287,7 @@ grid_sampler_3d_impl = _PyboostGridSampler3DPrim()
275
287
  class _PyboostHShrinkGradPrim(HShrinkGradPrim_):
276
288
  def __call__(self, gradients, features, lambd):
277
289
 
278
- return _convert_stub(super().__call__([gradients, features, lambd]))
290
+ return super().__call__([gradients, features, lambd])
279
291
 
280
292
 
281
293
  hshrink_grad_impl = _PyboostHShrinkGradPrim()
@@ -284,7 +296,7 @@ hshrink_grad_impl = _PyboostHShrinkGradPrim()
284
296
  class _PyboostHShrinkPrim(HShrinkPrim_):
285
297
  def __call__(self, input, lambd):
286
298
 
287
- return _convert_stub(super().__call__([input, lambd]))
299
+ return super().__call__([input, lambd])
288
300
 
289
301
 
290
302
  hshrink_impl = _PyboostHShrinkPrim()
@@ -293,7 +305,7 @@ hshrink_impl = _PyboostHShrinkPrim()
293
305
  class _PyboostIncreFlashAttentionPrim(IncreFlashAttentionPrim_):
294
306
  def __call__(self, query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, input_layout, scale_value, num_key_value_heads, block_size, inner_precise):
295
307
  converted_input_layout = str_to_enum('incre_flash_attention', 'input_layout', input_layout)
296
- return _convert_stub(super().__call__([query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, converted_input_layout, scale_value, num_key_value_heads, block_size, inner_precise]))
308
+ return super().__call__([query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, converted_input_layout, scale_value, num_key_value_heads, block_size, inner_precise])
297
309
 
298
310
 
299
311
  incre_flash_attention_impl = _PyboostIncreFlashAttentionPrim()
@@ -302,7 +314,7 @@ incre_flash_attention_impl = _PyboostIncreFlashAttentionPrim()
302
314
  class _PyboostIsClosePrim(IsClosePrim_):
303
315
  def __call__(self, input, other, rtol, atol, equal_nan):
304
316
 
305
- return _convert_stub(super().__call__([input, other, rtol, atol, equal_nan]))
317
+ return super().__call__([input, other, rtol, atol, equal_nan])
306
318
 
307
319
 
308
320
  isclose_impl = _PyboostIsClosePrim()
@@ -311,7 +323,7 @@ isclose_impl = _PyboostIsClosePrim()
311
323
  class _PyboostLogSoftmaxGradPrim(LogSoftmaxGradPrim_):
312
324
  def __call__(self, logits, grad, axis):
313
325
 
314
- return _convert_stub(super().__call__([logits, grad, axis]))
326
+ return super().__call__([logits, grad, axis])
315
327
 
316
328
 
317
329
  log_softmax_grad_impl = _PyboostLogSoftmaxGradPrim()
@@ -320,7 +332,7 @@ log_softmax_grad_impl = _PyboostLogSoftmaxGradPrim()
320
332
  class _PyboostLogSoftmaxPrim(LogSoftmaxPrim_):
321
333
  def __call__(self, logits, axis):
322
334
 
323
- return _convert_stub(super().__call__([logits, axis]))
335
+ return super().__call__([logits, axis])
324
336
 
325
337
 
326
338
  log_softmax_impl = _PyboostLogSoftmaxPrim()
@@ -329,7 +341,7 @@ log_softmax_impl = _PyboostLogSoftmaxPrim()
329
341
  class _PyboostMatMulPrim(MatMulPrim_):
330
342
  def __call__(self, input, mat2, transpose_a, transpose_b):
331
343
 
332
- return _convert_stub(super().__call__([input, mat2, transpose_a, transpose_b]))
344
+ return super().__call__([input, mat2, transpose_a, transpose_b])
333
345
 
334
346
 
335
347
  matmul_impl = _PyboostMatMulPrim()
@@ -341,7 +353,7 @@ class _PyboostMaxPoolGradWithIndicesPrim(MaxPoolGradWithIndicesPrim_):
341
353
  converted_strides = to_strides('max_pool_grad_with_indices', 'strides', strides)
342
354
  converted_pads = to_output_padding('max_pool_grad_with_indices', 'pads', pads)
343
355
  converted_dilation = to_dilations('max_pool_grad_with_indices', 'dilation', dilation)
344
- return _convert_stub(super().__call__([x, grad, argmax, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type]))
356
+ return super().__call__([x, grad, argmax, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
345
357
 
346
358
 
347
359
  max_pool_grad_with_indices_impl = _PyboostMaxPoolGradWithIndicesPrim()
@@ -353,7 +365,7 @@ class _PyboostMaxPoolGradWithMaskPrim(MaxPoolGradWithMaskPrim_):
353
365
  converted_strides = to_strides('max_pool_grad_with_mask', 'strides', strides)
354
366
  converted_pads = to_output_padding('max_pool_grad_with_mask', 'pads', pads)
355
367
  converted_dilation = to_dilations('max_pool_grad_with_mask', 'dilation', dilation)
356
- return _convert_stub(super().__call__([x, grad, mask, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type]))
368
+ return super().__call__([x, grad, mask, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
357
369
 
358
370
 
359
371
  max_pool_grad_with_mask_impl = _PyboostMaxPoolGradWithMaskPrim()
@@ -365,7 +377,7 @@ class _PyboostMaxPoolWithIndicesPrim(MaxPoolWithIndicesPrim_):
365
377
  converted_strides = to_strides('max_pool_with_indices', 'strides', strides)
366
378
  converted_pads = to_output_padding('max_pool_with_indices', 'pads', pads)
367
379
  converted_dilation = to_dilations('max_pool_with_indices', 'dilation', dilation)
368
- return _convert_stub(super().__call__([x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type]))
380
+ return super().__call__([x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
369
381
 
370
382
 
371
383
  max_pool_with_indices_impl = _PyboostMaxPoolWithIndicesPrim()
@@ -377,7 +389,7 @@ class _PyboostMaxPoolWithMaskPrim(MaxPoolWithMaskPrim_):
377
389
  converted_strides = to_strides('max_pool_with_mask', 'strides', strides)
378
390
  converted_pads = to_output_padding('max_pool_with_mask', 'pads', pads)
379
391
  converted_dilation = to_dilations('max_pool_with_mask', 'dilation', dilation)
380
- return _convert_stub(super().__call__([x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type]))
392
+ return super().__call__([x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
381
393
 
382
394
 
383
395
  max_pool_with_mask_impl = _PyboostMaxPoolWithMaskPrim()
@@ -386,7 +398,7 @@ max_pool_with_mask_impl = _PyboostMaxPoolWithMaskPrim()
386
398
  class _PyboostMeshgridPrim(MeshgridPrim_):
387
399
  def __call__(self, inputs, indexing):
388
400
  converted_indexing = str_to_enum('meshgrid', 'indexing', indexing)
389
- return _convert_stub(super().__call__([inputs, converted_indexing]))
401
+ return super().__call__([inputs, converted_indexing])
390
402
 
391
403
 
392
404
  meshgrid_impl = _PyboostMeshgridPrim()
@@ -395,7 +407,7 @@ meshgrid_impl = _PyboostMeshgridPrim()
395
407
  class _PyboostNanToNumPrim(NanToNumPrim_):
396
408
  def __call__(self, input, nan, posinf, neginf):
397
409
 
398
- return _convert_stub(super().__call__([input, nan, posinf, neginf]))
410
+ return super().__call__([input, nan, posinf, neginf])
399
411
 
400
412
 
401
413
  nan_to_num_impl = _PyboostNanToNumPrim()
@@ -404,7 +416,7 @@ nan_to_num_impl = _PyboostNanToNumPrim()
404
416
  class _PyboostNLLLossGradPrim(NLLLossGradPrim_):
405
417
  def __call__(self, logits, loss_grad, labels, weight, total_weight, reduction, ignore_index):
406
418
  converted_reduction = str_to_enum('nllloss_grad', 'reduction', reduction)
407
- return _convert_stub(super().__call__([logits, loss_grad, labels, weight, total_weight, converted_reduction, ignore_index]))
419
+ return super().__call__([logits, loss_grad, labels, weight, total_weight, converted_reduction, ignore_index])
408
420
 
409
421
 
410
422
  nllloss_grad_impl = _PyboostNLLLossGradPrim()
@@ -413,7 +425,7 @@ nllloss_grad_impl = _PyboostNLLLossGradPrim()
413
425
  class _PyboostNLLLossPrim(NLLLossPrim_):
414
426
  def __call__(self, logits, labels, weight, reduction, ignore_index):
415
427
  converted_reduction = str_to_enum('nllloss', 'reduction', reduction)
416
- return _convert_stub(super().__call__([logits, labels, weight, converted_reduction, ignore_index]))
428
+ return super().__call__([logits, labels, weight, converted_reduction, ignore_index])
417
429
 
418
430
 
419
431
  nllloss_impl = _PyboostNLLLossPrim()
@@ -422,7 +434,7 @@ nllloss_impl = _PyboostNLLLossPrim()
422
434
  class _PyboostOneHotExtPrim(OneHotExtPrim_):
423
435
  def __call__(self, tensor, num_classes, on_value, off_value, axis):
424
436
 
425
- return _convert_stub(super().__call__([tensor, num_classes, on_value, off_value, axis]))
437
+ return super().__call__([tensor, num_classes, on_value, off_value, axis])
426
438
 
427
439
 
428
440
  one_hot_ext_impl = _PyboostOneHotExtPrim()
@@ -431,7 +443,7 @@ one_hot_ext_impl = _PyboostOneHotExtPrim()
431
443
  class _PyboostPromptFlashAttentionPrim(PromptFlashAttentionPrim_):
432
444
  def __call__(self, query, key, value, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, pse_shift, deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads, sparse_mode, inner_precise):
433
445
  converted_input_layout = str_to_enum('prompt_flash_attention', 'input_layout', input_layout)
434
- return _convert_stub(super().__call__([query, key, value, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, pse_shift, deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, num_heads, scale_value, pre_tokens, next_tokens, converted_input_layout, num_key_value_heads, sparse_mode, inner_precise]))
446
+ return super().__call__([query, key, value, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, pse_shift, deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, num_heads, scale_value, pre_tokens, next_tokens, converted_input_layout, num_key_value_heads, sparse_mode, inner_precise])
435
447
 
436
448
 
437
449
  prompt_flash_attention_impl = _PyboostPromptFlashAttentionPrim()
@@ -440,7 +452,7 @@ prompt_flash_attention_impl = _PyboostPromptFlashAttentionPrim()
440
452
  class _PyboostReduceAllPrim(ReduceAllPrim_):
441
453
  def __call__(self, input, axis, keep_dims):
442
454
 
443
- return _convert_stub(super().__call__([input, axis, keep_dims]))
455
+ return super().__call__([input, axis, keep_dims])
444
456
 
445
457
 
446
458
  reduce_all_impl = _PyboostReduceAllPrim()
@@ -449,7 +461,7 @@ reduce_all_impl = _PyboostReduceAllPrim()
449
461
  class _PyboostReduceAnyPrim(ReduceAnyPrim_):
450
462
  def __call__(self, x, axis, keep_dims):
451
463
 
452
- return _convert_stub(super().__call__([x, axis, keep_dims]))
464
+ return super().__call__([x, axis, keep_dims])
453
465
 
454
466
 
455
467
  reduce_any_impl = _PyboostReduceAnyPrim()
@@ -458,7 +470,7 @@ reduce_any_impl = _PyboostReduceAnyPrim()
458
470
  class _PyboostReduceMaxPrim(ReduceMaxPrim_):
459
471
  def __call__(self, x, axis, keep_dims):
460
472
 
461
- return _convert_stub(super().__call__([x, axis, keep_dims]))
473
+ return super().__call__([x, axis, keep_dims])
462
474
 
463
475
 
464
476
  reduce_max_impl = _PyboostReduceMaxPrim()
@@ -467,7 +479,7 @@ reduce_max_impl = _PyboostReduceMaxPrim()
467
479
  class _PyboostReduceMinPrim(ReduceMinPrim_):
468
480
  def __call__(self, x, axis, keep_dims):
469
481
 
470
- return _convert_stub(super().__call__([x, axis, keep_dims]))
482
+ return super().__call__([x, axis, keep_dims])
471
483
 
472
484
 
473
485
  reduce_min_impl = _PyboostReduceMinPrim()
@@ -476,7 +488,7 @@ reduce_min_impl = _PyboostReduceMinPrim()
476
488
  class _PyboostReverseV2Prim(ReverseV2Prim_):
477
489
  def __call__(self, input, axis):
478
490
 
479
- return _convert_stub(super().__call__([input, axis]))
491
+ return super().__call__([input, axis])
480
492
 
481
493
 
482
494
  reverse_v2_impl = _PyboostReverseV2Prim()
@@ -485,16 +497,16 @@ reverse_v2_impl = _PyboostReverseV2Prim()
485
497
  class _PyboostRmsNormPrim(RmsNormPrim_):
486
498
  def __call__(self, x, gamma, epsilon):
487
499
 
488
- return _convert_stub(super().__call__([x, gamma, epsilon]))
500
+ return super().__call__([x, gamma, epsilon])
489
501
 
490
502
 
491
503
  rms_norm_impl = _PyboostRmsNormPrim()
492
504
 
493
505
 
494
506
  class _PyboostRollPrim(RollPrim_):
495
- def __call__(self, input, shift, axis):
507
+ def __call__(self, input, shifts, dims):
496
508
 
497
- return _convert_stub(super().__call__([input, shift, axis]))
509
+ return super().__call__([input, shifts, dims])
498
510
 
499
511
 
500
512
  roll_impl = _PyboostRollPrim()
@@ -503,7 +515,7 @@ roll_impl = _PyboostRollPrim()
503
515
  class _PyboostSearchSortedPrim(SearchSortedPrim_):
504
516
  def __call__(self, sorted_sequence, values, sorter, dtype, right):
505
517
 
506
- return _convert_stub(super().__call__([sorted_sequence, values, sorter, dtype, right]))
518
+ return super().__call__([sorted_sequence, values, sorter, dtype, right])
507
519
 
508
520
 
509
521
  searchsorted_impl = _PyboostSearchSortedPrim()
@@ -512,7 +524,7 @@ searchsorted_impl = _PyboostSearchSortedPrim()
512
524
  class _PyboostSmoothL1LossGradPrim(SmoothL1LossGradPrim_):
513
525
  def __call__(self, prediction, target, dout, beta, reduction):
514
526
  converted_reduction = str_to_enum('smooth_l1_loss_grad', 'reduction', reduction)
515
- return _convert_stub(super().__call__([prediction, target, dout, beta, converted_reduction]))
527
+ return super().__call__([prediction, target, dout, beta, converted_reduction])
516
528
 
517
529
 
518
530
  smooth_l1_loss_grad_impl = _PyboostSmoothL1LossGradPrim()
@@ -521,7 +533,7 @@ smooth_l1_loss_grad_impl = _PyboostSmoothL1LossGradPrim()
521
533
  class _PyboostSmoothL1LossPrim(SmoothL1LossPrim_):
522
534
  def __call__(self, prediction, target, beta, reduction):
523
535
  converted_reduction = str_to_enum('smooth_l1_loss', 'reduction', reduction)
524
- return _convert_stub(super().__call__([prediction, target, beta, converted_reduction]))
536
+ return super().__call__([prediction, target, beta, converted_reduction])
525
537
 
526
538
 
527
539
  smooth_l1_loss_impl = _PyboostSmoothL1LossPrim()
@@ -530,7 +542,7 @@ smooth_l1_loss_impl = _PyboostSmoothL1LossPrim()
530
542
  class _PyboostSoftmaxPrim(SoftmaxPrim_):
531
543
  def __call__(self, input, axis):
532
544
 
533
- return _convert_stub(super().__call__([input, axis]))
545
+ return super().__call__([input, axis])
534
546
 
535
547
 
536
548
  softmax_impl = _PyboostSoftmaxPrim()
@@ -539,7 +551,7 @@ softmax_impl = _PyboostSoftmaxPrim()
539
551
  class _PyboostSoftShrinkGradPrim(SoftShrinkGradPrim_):
540
552
  def __call__(self, input_grad, input_x, lambd):
541
553
 
542
- return _convert_stub(super().__call__([input_grad, input_x, lambd]))
554
+ return super().__call__([input_grad, input_x, lambd])
543
555
 
544
556
 
545
557
  softshrink_grad_impl = _PyboostSoftShrinkGradPrim()
@@ -548,16 +560,34 @@ softshrink_grad_impl = _PyboostSoftShrinkGradPrim()
548
560
  class _PyboostSoftShrinkPrim(SoftShrinkPrim_):
549
561
  def __call__(self, input, lambd):
550
562
 
551
- return _convert_stub(super().__call__([input, lambd]))
563
+ return super().__call__([input, lambd])
552
564
 
553
565
 
554
566
  softshrink_impl = _PyboostSoftShrinkPrim()
555
567
 
556
568
 
569
+ class _PyboostSoftMarginLossGradPrim(SoftMarginLossGradPrim_):
570
+ def __call__(self, predict, label, dout, reduction):
571
+ converted_reduction = str_to_enum('soft_margin_loss_grad', 'reduction', reduction)
572
+ return super().__call__([predict, label, dout, converted_reduction])
573
+
574
+
575
+ soft_margin_loss_grad_impl = _PyboostSoftMarginLossGradPrim()
576
+
577
+
578
+ class _PyboostSoftMarginLossPrim(SoftMarginLossPrim_):
579
+ def __call__(self, input, target, reduction):
580
+ converted_reduction = str_to_enum('soft_margin_loss', 'reduction', reduction)
581
+ return super().__call__([input, target, converted_reduction])
582
+
583
+
584
+ soft_margin_loss_impl = _PyboostSoftMarginLossPrim()
585
+
586
+
557
587
  class _PyboostSplitPrim(SplitPrim_):
558
588
  def __call__(self, input_x, axis, output_num):
559
589
 
560
- return _convert_stub(super().__call__([input_x, axis, output_num]))
590
+ return super().__call__([input_x, axis, output_num])
561
591
 
562
592
 
563
593
  split_impl = _PyboostSplitPrim()
@@ -566,7 +596,7 @@ split_impl = _PyboostSplitPrim()
566
596
  class _PyboostSqueezePrim(SqueezePrim_):
567
597
  def __call__(self, input, axis):
568
598
 
569
- return _convert_stub(super().__call__([input, axis]))
599
+ return super().__call__([input, axis])
570
600
 
571
601
 
572
602
  squeeze_impl = _PyboostSqueezePrim()
@@ -575,34 +605,25 @@ squeeze_impl = _PyboostSqueezePrim()
575
605
  class _PyboostStackExtPrim(StackExtPrim_):
576
606
  def __call__(self, tensors, dim):
577
607
 
578
- return _convert_stub(super().__call__([tensors, dim]))
608
+ return super().__call__([tensors, dim])
579
609
 
580
610
 
581
611
  stack_ext_impl = _PyboostStackExtPrim()
582
612
 
583
613
 
584
- class _PyboostTrilExtPrim(TrilExtPrim_):
585
- def __call__(self, input, diagonal):
586
-
587
- return _convert_stub(super().__call__([input, diagonal]))
588
-
589
-
590
- tril_ext_impl = _PyboostTrilExtPrim()
591
-
592
-
593
614
  class _PyboostTriuPrim(TriuPrim_):
594
615
  def __call__(self, input, diagonal):
595
616
 
596
- return _convert_stub(super().__call__([input, diagonal]))
617
+ return super().__call__([input, diagonal])
597
618
 
598
619
 
599
620
  triu_impl = _PyboostTriuPrim()
600
621
 
601
622
 
602
623
  class _PyboostUniqueConsecutivePrim(UniqueConsecutivePrim_):
603
- def __call__(self, input, return_idx, return_counts, axis):
624
+ def __call__(self, input, return_inverse, return_counts, dim):
604
625
 
605
- return _convert_stub(super().__call__([input, return_idx, return_counts, axis]))
626
+ return super().__call__([input, return_inverse, return_counts, dim])
606
627
 
607
628
 
608
629
  unique_consecutive_impl = _PyboostUniqueConsecutivePrim()
@@ -611,7 +632,7 @@ unique_consecutive_impl = _PyboostUniqueConsecutivePrim()
611
632
  class _PyboostUpsampleTrilinear3DGradPrim(UpsampleTrilinear3DGradPrim_):
612
633
  def __call__(self, dy, input_size, output_size, scales, align_corners):
613
634
 
614
- return _convert_stub(super().__call__([dy, input_size, output_size, scales, align_corners]))
635
+ return super().__call__([dy, input_size, output_size, scales, align_corners])
615
636
 
616
637
 
617
638
  upsample_trilinear3d_grad_impl = _PyboostUpsampleTrilinear3DGradPrim()
@@ -620,16 +641,25 @@ upsample_trilinear3d_grad_impl = _PyboostUpsampleTrilinear3DGradPrim()
620
641
  class _PyboostUpsampleTrilinear3DPrim(UpsampleTrilinear3DPrim_):
621
642
  def __call__(self, x, output_size, scales, align_corners):
622
643
 
623
- return _convert_stub(super().__call__([x, output_size, scales, align_corners]))
644
+ return super().__call__([x, output_size, scales, align_corners])
624
645
 
625
646
 
626
647
  upsample_trilinear3d_impl = _PyboostUpsampleTrilinear3DPrim()
627
648
 
628
649
 
650
+ class _PyboostFusedInferAttentionScorePrim(FusedInferAttentionScorePrim_):
651
+ def __call__(self, query, key, value, pse_shift, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode):
652
+ converted_input_layout = str_to_enum('fused_infer_attention_score', 'input_layout', input_layout)
653
+ return super().__call__([query, key, value, pse_shift, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, num_heads, scale_value, pre_tokens, next_tokens, converted_input_layout, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode])
654
+
655
+
656
+ fused_infer_attention_score_impl = _PyboostFusedInferAttentionScorePrim()
657
+
658
+
629
659
  class _PyboostGroupedMatmulPrim(GroupedMatmulPrim_):
630
- def __call__(self, x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type):
660
+ def __call__(self, x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type, transpose_a, transpose_b):
631
661
 
632
- return _convert_stub(super().__call__([x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type]))
662
+ return super().__call__([x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type, transpose_a, transpose_b])
633
663
 
634
664
 
635
665
  grouped_matmul_impl = _PyboostGroupedMatmulPrim()
@@ -638,7 +668,7 @@ grouped_matmul_impl = _PyboostGroupedMatmulPrim()
638
668
  class _PyboostQuantBatchMatmulPrim(QuantBatchMatmulPrim_):
639
669
  def __call__(self, x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype):
640
670
 
641
- return _convert_stub(super().__call__([x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype]))
671
+ return super().__call__([x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype])
642
672
 
643
673
 
644
674
  quant_batch_matmul_impl = _PyboostQuantBatchMatmulPrim()
@@ -647,7 +677,7 @@ quant_batch_matmul_impl = _PyboostQuantBatchMatmulPrim()
647
677
  class _PyboostWeightQuantBatchMatmulPrim(WeightQuantBatchMatmulPrim_):
648
678
  def __call__(self, x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, transpose_x, transpose_weight, antiquant_group_size):
649
679
 
650
- return _convert_stub(super().__call__([x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, transpose_x, transpose_weight, antiquant_group_size]))
680
+ return super().__call__([x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, transpose_x, transpose_weight, antiquant_group_size])
651
681
 
652
682
 
653
683
  weight_quant_batch_matmul_impl = _PyboostWeightQuantBatchMatmulPrim()
@@ -25,7 +25,7 @@ from mindspore.ops.composite.base import GradOperation, _Grad, HyperMap, Map, Mu
25
25
  from mindspore.ops.composite.env_ops import env_get
26
26
  from mindspore.ops.function.clip_func import clip_by_global_norm
27
27
  from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
28
- from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like
28
+ from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like, _ones_like_for_grad
29
29
  from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
30
30
  from mindspore.ops.function.random_func import normal, laplace, uniform, gamma, poisson, multinomial
31
31
  from mindspore.ops.composite.math_ops import matmul, cummin, mm
@@ -46,6 +46,7 @@ __all__ = [
46
46
  'hyper_add',
47
47
  'zeros_like',
48
48
  'ones_like',
49
+ '_ones_like_for_grad',
49
50
  'zip_operation',
50
51
  'normal',
51
52
  'laplace',