mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (579)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +13 -6
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -0
  7. mindspore/_checkparam.py +3 -38
  8. mindspore/_deprecated/__init__.py +17 -0
  9. mindspore/_deprecated/jit.py +198 -0
  10. mindspore/_extends/builtin_operations.py +1 -1
  11. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  12. mindspore/_extends/parse/__init__.py +6 -7
  13. mindspore/_extends/parse/compile_config.py +83 -0
  14. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +47 -198
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +229 -99
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +11 -5
  27. mindspore/avcodec-59.dll +0 -0
  28. mindspore/avdevice-59.dll +0 -0
  29. mindspore/avfilter-8.dll +0 -0
  30. mindspore/avformat-59.dll +0 -0
  31. mindspore/avutil-57.dll +0 -0
  32. mindspore/boost/__init__.py +2 -2
  33. mindspore/boost/base.py +3 -7
  34. mindspore/boost/boost_cell_wrapper.py +138 -43
  35. mindspore/common/__init__.py +6 -3
  36. mindspore/common/_grad_function.py +56 -0
  37. mindspore/common/_pijit_context.py +14 -5
  38. mindspore/common/_register_for_tensor.py +1 -2
  39. mindspore/common/_stub_tensor.py +30 -14
  40. mindspore/common/_tensor_cpp_method.py +17 -0
  41. mindspore/common/_tensor_docs.py +4760 -0
  42. mindspore/common/api.py +480 -372
  43. mindspore/common/auto_dynamic_shape.py +41 -44
  44. mindspore/common/dtype.py +39 -36
  45. mindspore/common/dump.py +9 -6
  46. mindspore/common/file_system.py +9 -1
  47. mindspore/common/generator.py +5 -0
  48. mindspore/common/hook_handle.py +6 -2
  49. mindspore/common/initializer.py +13 -10
  50. mindspore/common/jit_begin_end.py +94 -0
  51. mindspore/common/jit_config.py +6 -1
  52. mindspore/common/jit_context.py +76 -0
  53. mindspore/common/jit_trace.py +378 -0
  54. mindspore/common/lazy_inline.py +9 -3
  55. mindspore/common/mindir_util.py +10 -2
  56. mindspore/common/mutable.py +5 -4
  57. mindspore/common/parameter.py +135 -52
  58. mindspore/common/seed.py +2 -2
  59. mindspore/common/sparse_tensor.py +23 -17
  60. mindspore/common/tensor.py +975 -1981
  61. mindspore/communication/__init__.py +7 -5
  62. mindspore/communication/_comm_helper.py +52 -2
  63. mindspore/communication/comm_func.py +240 -181
  64. mindspore/communication/management.py +95 -26
  65. mindspore/context.py +324 -573
  66. mindspore/dataset/__init__.py +65 -37
  67. mindspore/dataset/audio/__init__.py +2 -8
  68. mindspore/dataset/audio/transforms.py +3 -17
  69. mindspore/dataset/callback/ds_callback.py +2 -1
  70. mindspore/dataset/core/config.py +87 -6
  71. mindspore/dataset/engine/cache_admin.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +6 -5
  73. mindspore/dataset/engine/datasets.py +292 -267
  74. mindspore/dataset/engine/datasets_audio.py +22 -8
  75. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  76. mindspore/dataset/engine/datasets_text.py +78 -48
  77. mindspore/dataset/engine/datasets_user_defined.py +183 -117
  78. mindspore/dataset/engine/datasets_vision.py +120 -44
  79. mindspore/dataset/engine/iterators.py +283 -63
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/obs/util.py +8 -0
  82. mindspore/dataset/engine/queue.py +40 -0
  83. mindspore/dataset/engine/samplers.py +289 -43
  84. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  85. mindspore/dataset/engine/validators.py +53 -11
  86. mindspore/dataset/text/__init__.py +7 -6
  87. mindspore/dataset/text/transforms.py +6 -5
  88. mindspore/dataset/text/utils.py +3 -3
  89. mindspore/dataset/transforms/__init__.py +0 -9
  90. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  91. mindspore/dataset/transforms/transforms.py +31 -14
  92. mindspore/dataset/utils/browse_dataset.py +1 -1
  93. mindspore/dataset/vision/__init__.py +2 -9
  94. mindspore/dataset/vision/transforms.py +202 -158
  95. mindspore/dataset/vision/utils.py +7 -5
  96. mindspore/dataset/vision/validators.py +1 -2
  97. mindspore/device_context/__init__.py +21 -0
  98. mindspore/device_context/ascend/__init__.py +25 -0
  99. mindspore/device_context/ascend/device.py +72 -0
  100. mindspore/device_context/ascend/op_debug.py +153 -0
  101. mindspore/device_context/ascend/op_precision.py +193 -0
  102. mindspore/device_context/ascend/op_tuning.py +123 -0
  103. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  104. mindspore/device_context/cpu/device.py +62 -0
  105. mindspore/device_context/cpu/op_tuning.py +43 -0
  106. mindspore/device_context/gpu/__init__.py +21 -0
  107. mindspore/device_context/gpu/device.py +70 -0
  108. mindspore/device_context/gpu/op_precision.py +67 -0
  109. mindspore/device_context/gpu/op_tuning.py +175 -0
  110. mindspore/device_manager.py +170 -0
  111. mindspore/dnnl.dll +0 -0
  112. mindspore/experimental/es/embedding_service.py +35 -27
  113. mindspore/experimental/llm_boost/__init__.py +1 -0
  114. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  115. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
  116. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  117. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  118. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  119. mindspore/experimental/llm_boost/register.py +1 -0
  120. mindspore/experimental/map_parameter.py +4 -4
  121. mindspore/experimental/optim/adadelta.py +6 -6
  122. mindspore/experimental/optim/adagrad.py +4 -4
  123. mindspore/experimental/optim/adam.py +7 -0
  124. mindspore/experimental/optim/adamax.py +4 -4
  125. mindspore/experimental/optim/adamw.py +4 -0
  126. mindspore/experimental/optim/asgd.py +1 -1
  127. mindspore/experimental/optim/lr_scheduler.py +73 -46
  128. mindspore/experimental/optim/radam.py +34 -31
  129. mindspore/experimental/optim/rprop.py +1 -1
  130. mindspore/experimental/optim/sgd.py +1 -1
  131. mindspore/hal/contiguous_tensors_handle.py +6 -10
  132. mindspore/hal/device.py +55 -53
  133. mindspore/hal/event.py +52 -52
  134. mindspore/hal/memory.py +179 -120
  135. mindspore/hal/stream.py +150 -109
  136. mindspore/include/api/context.h +0 -1
  137. mindspore/include/dataset/constants.h +7 -4
  138. mindspore/include/dataset/execute.h +2 -2
  139. mindspore/jpeg62.dll +0 -0
  140. mindspore/log.py +50 -0
  141. mindspore/mindrecord/__init__.py +21 -8
  142. mindspore/mindrecord/config.py +17 -316
  143. mindspore/mindrecord/filereader.py +1 -9
  144. mindspore/mindrecord/filewriter.py +5 -15
  145. mindspore/mindrecord/mindpage.py +1 -9
  146. mindspore/mindspore_backend_common.dll +0 -0
  147. mindspore/mindspore_backend_manager.dll +0 -0
  148. mindspore/mindspore_common.dll +0 -0
  149. mindspore/mindspore_core.dll +0 -0
  150. mindspore/mindspore_dump.dll +0 -0
  151. mindspore/mindspore_frontend.dll +0 -0
  152. mindspore/mindspore_glog.dll +0 -0
  153. mindspore/mindspore_memory_pool.dll +0 -0
  154. mindspore/mindspore_ms_backend.dll +0 -0
  155. mindspore/mindspore_ops.dll +0 -0
  156. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  157. mindspore/mindspore_ops_kernel_common.dll +0 -0
  158. mindspore/mindspore_profiler.dll +0 -0
  159. mindspore/mindspore_pyboost.dll +0 -0
  160. mindspore/mindspore_pynative.dll +0 -0
  161. mindspore/mindspore_res_manager.dll +0 -0
  162. mindspore/mindspore_runtime_pipeline.dll +0 -0
  163. mindspore/mint/__init__.py +798 -761
  164. mindspore/mint/distributed/__init__.py +70 -4
  165. mindspore/mint/distributed/distributed.py +2679 -44
  166. mindspore/mint/linalg/__init__.py +8 -0
  167. mindspore/mint/nn/__init__.py +743 -22
  168. mindspore/mint/nn/functional.py +716 -23
  169. mindspore/mint/nn/layer/__init__.py +21 -4
  170. mindspore/mint/nn/layer/_functions.py +334 -0
  171. mindspore/mint/nn/layer/activation.py +276 -1
  172. mindspore/mint/nn/layer/basic.py +123 -0
  173. mindspore/mint/nn/layer/conv.py +933 -0
  174. mindspore/mint/nn/layer/normalization.py +223 -28
  175. mindspore/mint/nn/layer/padding.py +797 -0
  176. mindspore/mint/nn/layer/pooling.py +235 -0
  177. mindspore/mint/optim/__init__.py +3 -1
  178. mindspore/mint/optim/adam.py +223 -0
  179. mindspore/mint/optim/adamw.py +26 -19
  180. mindspore/mint/optim/sgd.py +171 -0
  181. mindspore/mint/special/__init__.py +2 -1
  182. mindspore/multiprocessing/__init__.py +5 -0
  183. mindspore/nn/__init__.py +4 -1
  184. mindspore/nn/cell.py +1373 -192
  185. mindspore/nn/dynamic_lr.py +2 -1
  186. mindspore/nn/layer/activation.py +29 -27
  187. mindspore/nn/layer/basic.py +51 -35
  188. mindspore/nn/layer/channel_shuffle.py +3 -3
  189. mindspore/nn/layer/container.py +1 -1
  190. mindspore/nn/layer/conv.py +53 -42
  191. mindspore/nn/layer/embedding.py +12 -11
  192. mindspore/nn/layer/normalization.py +56 -49
  193. mindspore/nn/layer/padding.py +4 -3
  194. mindspore/nn/layer/pooling.py +120 -42
  195. mindspore/nn/layer/rnn_cells.py +1 -1
  196. mindspore/nn/layer/rnns.py +2 -1
  197. mindspore/nn/layer/timedistributed.py +5 -5
  198. mindspore/nn/layer/transformer.py +59 -36
  199. mindspore/nn/learning_rate_schedule.py +8 -4
  200. mindspore/nn/loss/loss.py +58 -55
  201. mindspore/nn/optim/ada_grad.py +7 -5
  202. mindspore/nn/optim/adadelta.py +11 -9
  203. mindspore/nn/optim/adafactor.py +1 -1
  204. mindspore/nn/optim/adam.py +19 -15
  205. mindspore/nn/optim/adamax.py +8 -7
  206. mindspore/nn/optim/adasum.py +5 -5
  207. mindspore/nn/optim/asgd.py +3 -1
  208. mindspore/nn/optim/ftrl.py +11 -9
  209. mindspore/nn/optim/lamb.py +1 -1
  210. mindspore/nn/optim/lars.py +1 -4
  211. mindspore/nn/optim/lazyadam.py +12 -10
  212. mindspore/nn/optim/momentum.py +7 -6
  213. mindspore/nn/optim/optimizer.py +3 -3
  214. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  215. mindspore/nn/optim/rmsprop.py +13 -12
  216. mindspore/nn/optim/rprop.py +11 -9
  217. mindspore/nn/optim/sgd.py +9 -6
  218. mindspore/nn/optim/tft_wrapper.py +5 -2
  219. mindspore/nn/optim/thor.py +2 -1
  220. mindspore/nn/probability/bijector/bijector.py +17 -11
  221. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  222. mindspore/nn/probability/bijector/invert.py +2 -2
  223. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  224. mindspore/nn/probability/bijector/softplus.py +3 -2
  225. mindspore/nn/probability/distribution/beta.py +3 -3
  226. mindspore/nn/probability/distribution/categorical.py +1 -1
  227. mindspore/nn/probability/distribution/cauchy.py +4 -2
  228. mindspore/nn/probability/distribution/exponential.py +6 -7
  229. mindspore/nn/probability/distribution/gamma.py +2 -2
  230. mindspore/nn/probability/distribution/gumbel.py +2 -2
  231. mindspore/nn/probability/distribution/half_normal.py +5 -3
  232. mindspore/nn/probability/distribution/logistic.py +5 -3
  233. mindspore/nn/probability/distribution/poisson.py +1 -1
  234. mindspore/nn/probability/distribution/uniform.py +5 -3
  235. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  236. mindspore/nn/reinforcement/tensor_array.py +1 -1
  237. mindspore/nn/utils/init.py +13 -11
  238. mindspore/nn/wrap/__init__.py +6 -6
  239. mindspore/nn/wrap/cell_wrapper.py +181 -122
  240. mindspore/nn/wrap/grad_reducer.py +45 -36
  241. mindspore/nn/wrap/loss_scale.py +6 -7
  242. mindspore/numpy/array_creations.py +63 -65
  243. mindspore/numpy/array_ops.py +149 -144
  244. mindspore/numpy/logic_ops.py +41 -42
  245. mindspore/numpy/math_ops.py +361 -359
  246. mindspore/numpy/utils.py +17 -18
  247. mindspore/numpy/utils_const.py +5 -6
  248. mindspore/opencv_core452.dll +0 -0
  249. mindspore/opencv_imgcodecs452.dll +0 -0
  250. mindspore/opencv_imgproc452.dll +0 -0
  251. mindspore/ops/__init__.py +5 -3
  252. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  253. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  256. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  257. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  258. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  259. mindspore/ops/_register_for_op.py +0 -11
  260. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  261. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  262. mindspore/ops/_vmap/vmap_array_ops.py +52 -25
  263. mindspore/ops/_vmap/vmap_base.py +0 -2
  264. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  265. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  266. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  267. mindspore/ops/auto_generate/__init__.py +4 -3
  268. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
  269. mindspore/ops/auto_generate/gen_extend_func.py +757 -185
  270. mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
  271. mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
  272. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  273. mindspore/ops/composite/__init__.py +2 -1
  274. mindspore/ops/composite/base.py +20 -25
  275. mindspore/ops/composite/math_ops.py +6 -16
  276. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  277. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  278. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  279. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  280. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  281. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  282. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  283. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  284. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  285. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  286. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  287. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  288. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  289. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  290. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  291. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  292. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  293. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  294. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  295. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  299. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  301. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  302. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  303. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  304. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  305. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  306. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  307. mindspore/ops/function/__init__.py +40 -2
  308. mindspore/ops/function/_add_attr_func.py +58 -0
  309. mindspore/ops/function/array_func.py +2089 -2403
  310. mindspore/ops/function/clip_func.py +80 -23
  311. mindspore/ops/function/debug_func.py +57 -57
  312. mindspore/ops/function/grad/__init__.py +1 -0
  313. mindspore/ops/function/grad/grad_func.py +104 -71
  314. mindspore/ops/function/image_func.py +2 -2
  315. mindspore/ops/function/linalg_func.py +47 -78
  316. mindspore/ops/function/math_func.py +4351 -3813
  317. mindspore/ops/function/nn_func.py +1712 -637
  318. mindspore/ops/function/other_func.py +159 -1
  319. mindspore/ops/function/parameter_func.py +18 -84
  320. mindspore/ops/function/random_func.py +452 -387
  321. mindspore/ops/function/reshard_func.py +4 -70
  322. mindspore/ops/function/sparse_func.py +3 -3
  323. mindspore/ops/function/sparse_unary_func.py +6 -6
  324. mindspore/ops/function/spectral_func.py +25 -58
  325. mindspore/ops/function/vmap_func.py +26 -18
  326. mindspore/ops/functional.py +23 -7
  327. mindspore/ops/functional_overload.py +1548 -0
  328. mindspore/ops/op_info_register.py +32 -244
  329. mindspore/ops/operations/__init__.py +23 -15
  330. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  331. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  332. mindspore/ops/operations/_grad_ops.py +2 -43
  333. mindspore/ops/operations/_infer_ops.py +2 -1
  334. mindspore/ops/operations/_inner_ops.py +43 -84
  335. mindspore/ops/operations/_ms_kernel.py +4 -10
  336. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  337. mindspore/ops/operations/_scalar_ops.py +3 -2
  338. mindspore/ops/operations/_sequence_ops.py +1 -1
  339. mindspore/ops/operations/_tensor_array.py +1 -1
  340. mindspore/ops/operations/array_ops.py +81 -324
  341. mindspore/ops/operations/comm_ops.py +154 -108
  342. mindspore/ops/operations/custom_ops.py +298 -87
  343. mindspore/ops/operations/debug_ops.py +157 -59
  344. mindspore/ops/operations/inner_ops.py +7 -5
  345. mindspore/ops/operations/linalg_ops.py +1 -57
  346. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  347. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  348. mindspore/ops/operations/math_ops.py +32 -234
  349. mindspore/ops/operations/nn_ops.py +212 -531
  350. mindspore/ops/operations/other_ops.py +62 -9
  351. mindspore/ops/operations/random_ops.py +13 -7
  352. mindspore/ops/operations/reshard_ops.py +1 -1
  353. mindspore/ops/operations/sparse_ops.py +2 -2
  354. mindspore/ops/primitive.py +66 -53
  355. mindspore/ops/tensor_method.py +1895 -0
  356. mindspore/ops_generate/__init__.py +0 -5
  357. mindspore/ops_generate/aclnn/__init__.py +0 -0
  358. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  359. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  360. mindspore/ops_generate/api/__init__.py +0 -0
  361. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  362. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  363. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  364. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  365. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  366. mindspore/ops_generate/api/gen_api.py +103 -0
  367. mindspore/ops_generate/api/op_api_proto.py +235 -0
  368. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  369. mindspore/ops_generate/common/__init__.py +0 -0
  370. mindspore/ops_generate/common/base_generator.py +11 -0
  371. mindspore/ops_generate/common/gen_constants.py +91 -0
  372. mindspore/ops_generate/common/gen_utils.py +348 -0
  373. mindspore/ops_generate/common/op_proto.py +473 -0
  374. mindspore/ops_generate/common/template.py +523 -0
  375. mindspore/ops_generate/gen_ops.py +22 -1069
  376. mindspore/ops_generate/op_def/__init__.py +0 -0
  377. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  378. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  379. mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
  380. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  381. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  382. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  383. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  384. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  385. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  386. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  387. mindspore/ops_generate/pyboost/__init__.py +0 -0
  388. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  389. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  390. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  391. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  392. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  393. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  394. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  395. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  396. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  397. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  398. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  399. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  400. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  401. mindspore/ops_generate/resources/__init__.py +0 -0
  402. mindspore/ops_generate/resources/resource_list.py +30 -0
  403. mindspore/ops_generate/resources/resource_loader.py +36 -0
  404. mindspore/ops_generate/resources/resource_manager.py +64 -0
  405. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  406. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  407. mindspore/parallel/__init__.py +7 -3
  408. mindspore/parallel/_auto_parallel_context.py +159 -40
  409. mindspore/parallel/_cell_wrapper.py +132 -15
  410. mindspore/parallel/_parallel_serialization.py +107 -5
  411. mindspore/parallel/_ps_context.py +1 -1
  412. mindspore/parallel/_recovery_context.py +7 -2
  413. mindspore/parallel/_tensor.py +142 -18
  414. mindspore/parallel/_utils.py +199 -23
  415. mindspore/parallel/algo_parameter_config.py +4 -4
  416. mindspore/parallel/auto_parallel.py +732 -0
  417. mindspore/parallel/checkpoint_convert.py +159 -0
  418. mindspore/parallel/checkpoint_transform.py +700 -35
  419. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  420. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  421. mindspore/parallel/cluster/run.py +21 -4
  422. mindspore/parallel/function/__init__.py +24 -0
  423. mindspore/parallel/function/reshard_func.py +258 -0
  424. mindspore/parallel/nn/__init__.py +25 -0
  425. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  426. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  427. mindspore/parallel/parameter_broadcast.py +25 -14
  428. mindspore/parallel/shard.py +137 -59
  429. mindspore/parallel/transform_safetensors.py +364 -305
  430. mindspore/profiler/__init__.py +22 -5
  431. mindspore/profiler/analysis/__init__.py +0 -0
  432. mindspore/profiler/analysis/parser/__init__.py +0 -0
  433. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  434. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  435. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  436. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  437. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  438. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  439. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  440. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  441. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
  442. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  443. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  444. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  445. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  446. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  447. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  448. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  449. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  450. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  451. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  452. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  453. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  454. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  455. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  456. mindspore/profiler/analysis/task_manager.py +131 -0
  457. mindspore/profiler/analysis/time_converter.py +84 -0
  458. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  459. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  460. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  461. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  462. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  463. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  464. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  465. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  466. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  467. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  468. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  469. mindspore/profiler/analysis/work_flow.py +73 -0
  470. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  471. mindspore/profiler/common/command_executor.py +90 -0
  472. mindspore/profiler/common/constant.py +186 -3
  473. mindspore/profiler/common/file_manager.py +208 -0
  474. mindspore/profiler/common/log.py +130 -0
  475. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  476. mindspore/profiler/common/path_manager.py +395 -0
  477. mindspore/profiler/common/process_bar.py +168 -0
  478. mindspore/profiler/common/process_pool.py +9 -3
  479. mindspore/profiler/common/profiler_context.py +500 -0
  480. mindspore/profiler/common/profiler_info.py +304 -0
  481. mindspore/profiler/common/profiler_meta_data.py +74 -0
  482. mindspore/profiler/common/profiler_output_path.py +284 -0
  483. mindspore/profiler/common/profiler_parameters.py +251 -0
  484. mindspore/profiler/common/profiler_path_manager.py +179 -0
  485. mindspore/profiler/common/record_function.py +76 -0
  486. mindspore/profiler/common/tlv_decoder.py +76 -0
  487. mindspore/profiler/common/util.py +75 -2
  488. mindspore/profiler/dynamic_profiler.py +341 -75
  489. mindspore/profiler/envprofiler.py +163 -0
  490. mindspore/profiler/experimental_config.py +197 -0
  491. mindspore/profiler/mstx.py +242 -0
  492. mindspore/profiler/platform/__init__.py +21 -0
  493. mindspore/profiler/platform/base_profiler.py +40 -0
  494. mindspore/profiler/platform/cpu_profiler.py +124 -0
  495. mindspore/profiler/platform/gpu_profiler.py +74 -0
  496. mindspore/profiler/platform/npu_profiler.py +335 -0
  497. mindspore/profiler/profiler.py +1073 -90
  498. mindspore/profiler/profiler_action_controller.py +187 -0
  499. mindspore/profiler/profiler_interface.py +118 -0
  500. mindspore/profiler/schedule.py +243 -0
  501. mindspore/rewrite/api/node.py +15 -13
  502. mindspore/rewrite/api/symbol_tree.py +2 -3
  503. mindspore/run_check/_check_version.py +27 -20
  504. mindspore/run_check/run_check.py +1 -1
  505. mindspore/runtime/__init__.py +37 -0
  506. mindspore/runtime/device.py +27 -0
  507. mindspore/runtime/event.py +209 -0
  508. mindspore/runtime/executor.py +177 -0
  509. mindspore/runtime/memory.py +416 -0
  510. mindspore/runtime/stream.py +460 -0
  511. mindspore/runtime/thread_bind_core.py +401 -0
  512. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  513. mindspore/swresample-4.dll +0 -0
  514. mindspore/swscale-6.dll +0 -0
  515. mindspore/tinyxml2.dll +0 -0
  516. mindspore/train/__init__.py +8 -8
  517. mindspore/train/_utils.py +96 -27
  518. mindspore/train/amp.py +9 -5
  519. mindspore/train/callback/__init__.py +2 -2
  520. mindspore/train/callback/_callback.py +2 -16
  521. mindspore/train/callback/_checkpoint.py +53 -55
  522. mindspore/train/callback/_cluster_monitor.py +14 -18
  523. mindspore/train/callback/_early_stop.py +1 -1
  524. mindspore/train/callback/_flops_collector.py +103 -68
  525. mindspore/train/callback/_history.py +8 -5
  526. mindspore/train/callback/_lambda_callback.py +2 -2
  527. mindspore/train/callback/_landscape.py +0 -3
  528. mindspore/train/callback/_loss_monitor.py +2 -1
  529. mindspore/train/callback/_on_request_exit.py +6 -5
  530. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  531. mindspore/train/callback/_summary_collector.py +52 -19
  532. mindspore/train/callback/_time_monitor.py +2 -1
  533. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
  534. mindspore/train/data_sink.py +25 -2
  535. mindspore/train/dataset_helper.py +15 -16
  536. mindspore/train/loss_scale_manager.py +8 -7
  537. mindspore/train/metrics/accuracy.py +3 -3
  538. mindspore/train/metrics/confusion_matrix.py +9 -9
  539. mindspore/train/metrics/error.py +3 -3
  540. mindspore/train/metrics/hausdorff_distance.py +4 -4
  541. mindspore/train/metrics/mean_surface_distance.py +3 -3
  542. mindspore/train/metrics/metric.py +0 -12
  543. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  544. mindspore/train/metrics/precision.py +11 -10
  545. mindspore/train/metrics/recall.py +9 -9
  546. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  547. mindspore/train/mind_ir_pb2.py +174 -46
  548. mindspore/train/model.py +269 -136
  549. mindspore/train/serialization.py +622 -978
  550. mindspore/train/summary/_summary_adapter.py +2 -2
  551. mindspore/train/summary/summary_record.py +2 -3
  552. mindspore/train/train_thor/model_thor.py +1 -1
  553. mindspore/turbojpeg.dll +0 -0
  554. mindspore/utils/__init__.py +6 -3
  555. mindspore/utils/dryrun.py +140 -0
  556. mindspore/utils/hooks.py +81 -0
  557. mindspore/utils/runtime_execution_order_check.py +552 -0
  558. mindspore/utils/utils.py +138 -4
  559. mindspore/version.py +1 -1
  560. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
  561. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +564 -395
  562. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
  563. mindspore/_install_custom.py +0 -43
  564. mindspore/common/_register_for_adapter.py +0 -74
  565. mindspore/common/_tensor_overload.py +0 -139
  566. mindspore/mindspore_np_dtype.dll +0 -0
  567. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  568. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  569. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  570. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  571. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  572. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  573. mindspore/ops_generate/gen_utils.py +0 -209
  574. mindspore/ops_generate/op_proto.py +0 -145
  575. mindspore/ops_generate/template.py +0 -261
  576. mindspore/profiler/envprofiling.py +0 -254
  577. mindspore/profiler/profiling.py +0 -1926
  578. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  579. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,8 @@
14
14
  # ============================================================================
15
15
 
16
16
  from mindspore.common._stub_tensor import _convert_stub
17
- from mindspore.ops.auto_generate.gen_arg_handler import *
17
+ from mindspore.ops._utils.arg_handler import *
18
+ from mindspore._c_expression import AdaptiveMaxPool2DPrim_
18
19
  from mindspore._c_expression import ArgMaxWithValuePrim_
19
20
  from mindspore._c_expression import ArgMinWithValuePrim_
20
21
  from mindspore._c_expression import BatchMatMulPrim_
@@ -24,14 +25,14 @@ from mindspore._c_expression import BinaryCrossEntropyPrim_
24
25
  from mindspore._c_expression import BCEWithLogitsLossPrim_
25
26
  from mindspore._c_expression import BroadcastToPrim_
26
27
  from mindspore._c_expression import ConcatPrim_
27
- from mindspore._c_expression import ConvolutionGradPrim_
28
- from mindspore._c_expression import ConvolutionPrim_
29
28
  from mindspore._c_expression import CrossPrim_
30
29
  from mindspore._c_expression import CummaxPrim_
31
30
  from mindspore._c_expression import EluExtPrim_
32
31
  from mindspore._c_expression import FFNExtPrim_
33
32
  from mindspore._c_expression import FlashAttentionScoreGradPrim_
34
33
  from mindspore._c_expression import FlashAttentionScorePrim_
34
+ from mindspore._c_expression import GluGradPrim_
35
+ from mindspore._c_expression import GLUPrim_
35
36
  from mindspore._c_expression import GridSampler2DGradPrim_
36
37
  from mindspore._c_expression import GridSampler2DPrim_
37
38
  from mindspore._c_expression import GridSampler3DGradPrim_
@@ -47,31 +48,53 @@ from mindspore._c_expression import MaxPoolGradWithIndicesPrim_
47
48
  from mindspore._c_expression import MaxPoolGradWithMaskPrim_
48
49
  from mindspore._c_expression import MaxPoolWithIndicesPrim_
49
50
  from mindspore._c_expression import MaxPoolWithMaskPrim_
51
+ from mindspore._c_expression import MeshgridPrim_
50
52
  from mindspore._c_expression import NanToNumPrim_
53
+ from mindspore._c_expression import NLLLossGradPrim_
54
+ from mindspore._c_expression import NLLLossPrim_
51
55
  from mindspore._c_expression import OneHotExtPrim_
56
+ from mindspore._c_expression import PromptFlashAttentionPrim_
52
57
  from mindspore._c_expression import ReduceAllPrim_
53
58
  from mindspore._c_expression import ReduceAnyPrim_
59
+ from mindspore._c_expression import ReduceMaxPrim_
60
+ from mindspore._c_expression import ReduceMinPrim_
54
61
  from mindspore._c_expression import ReverseV2Prim_
55
62
  from mindspore._c_expression import RmsNormPrim_
56
63
  from mindspore._c_expression import RollPrim_
57
64
  from mindspore._c_expression import SearchSortedPrim_
65
+ from mindspore._c_expression import SmoothL1LossGradPrim_
66
+ from mindspore._c_expression import SmoothL1LossPrim_
58
67
  from mindspore._c_expression import SoftmaxPrim_
59
68
  from mindspore._c_expression import SoftShrinkGradPrim_
60
69
  from mindspore._c_expression import SoftShrinkPrim_
70
+ from mindspore._c_expression import SoftMarginLossGradPrim_
71
+ from mindspore._c_expression import SoftMarginLossPrim_
72
+ from mindspore._c_expression import SplitPrim_
73
+ from mindspore._c_expression import SqueezePrim_
61
74
  from mindspore._c_expression import StackExtPrim_
62
- from mindspore._c_expression import TrilExtPrim_
63
75
  from mindspore._c_expression import TriuPrim_
76
+ from mindspore._c_expression import UniqueConsecutivePrim_
64
77
  from mindspore._c_expression import UpsampleTrilinear3DGradPrim_
65
78
  from mindspore._c_expression import UpsampleTrilinear3DPrim_
79
+ from mindspore._c_expression import FusedInferAttentionScorePrim_
66
80
  from mindspore._c_expression import GroupedMatmulPrim_
67
81
  from mindspore._c_expression import QuantBatchMatmulPrim_
68
82
  from mindspore._c_expression import WeightQuantBatchMatmulPrim_
69
83
 
70
84
 
85
+ class _PyboostAdaptiveMaxPool2DPrim(AdaptiveMaxPool2DPrim_):
86
+ def __call__(self, input, output_size):
87
+
88
+ return super().__call__([input, output_size])
89
+
90
+
91
+ adaptive_max_pool2d_impl = _PyboostAdaptiveMaxPool2DPrim()
92
+
93
+
71
94
  class _PyboostArgMaxWithValuePrim(ArgMaxWithValuePrim_):
72
95
  def __call__(self, input, axis, keep_dims):
73
96
 
74
- return _convert_stub(super().__call__(input, axis, keep_dims))
97
+ return super().__call__([input, axis, keep_dims])
75
98
 
76
99
 
77
100
  argmax_with_value_impl = _PyboostArgMaxWithValuePrim()
@@ -80,7 +103,7 @@ argmax_with_value_impl = _PyboostArgMaxWithValuePrim()
80
103
  class _PyboostArgMinWithValuePrim(ArgMinWithValuePrim_):
81
104
  def __call__(self, input, axis, keep_dims):
82
105
 
83
- return _convert_stub(super().__call__(input, axis, keep_dims))
106
+ return super().__call__([input, axis, keep_dims])
84
107
 
85
108
 
86
109
  argmin_with_value_impl = _PyboostArgMinWithValuePrim()
@@ -89,16 +112,16 @@ argmin_with_value_impl = _PyboostArgMinWithValuePrim()
89
112
  class _PyboostBatchMatMulPrim(BatchMatMulPrim_):
90
113
  def __call__(self, x, y, transpose_a, transpose_b):
91
114
 
92
- return _convert_stub(super().__call__(x, y, transpose_a, transpose_b))
115
+ return super().__call__([x, y, transpose_a, transpose_b])
93
116
 
94
117
 
95
118
  batch_mat_mul_impl = _PyboostBatchMatMulPrim()
96
119
 
97
120
 
98
121
  class _PyboostBatchNormGradExtPrim(BatchNormGradExtPrim_):
99
- def __call__(self, dout, input, weight, running_mean, running_var, saved_mean, saved_rstd, training, eps):
122
+ def __call__(self, dout, input, weight, running_mean, running_var, saved_mean, saved_rstd, training, eps, output_mask):
100
123
 
101
- return _convert_stub(super().__call__(dout, input, weight, running_mean, running_var, saved_mean, saved_rstd, training, eps))
124
+ return super().__call__([dout, input, weight, running_mean, running_var, saved_mean, saved_rstd, training, eps, output_mask])
102
125
 
103
126
 
104
127
  batch_norm_grad_ext_impl = _PyboostBatchNormGradExtPrim()
@@ -107,7 +130,7 @@ batch_norm_grad_ext_impl = _PyboostBatchNormGradExtPrim()
107
130
  class _PyboostBinaryCrossEntropyGradPrim(BinaryCrossEntropyGradPrim_):
108
131
  def __call__(self, input, target, grad_output, weight, reduction):
109
132
  converted_reduction = str_to_enum('binary_cross_entropy_grad', 'reduction', reduction)
110
- return _convert_stub(super().__call__(input, target, grad_output, weight, converted_reduction))
133
+ return super().__call__([input, target, grad_output, weight, converted_reduction])
111
134
 
112
135
 
113
136
  binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()
@@ -116,7 +139,7 @@ binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()
116
139
  class _PyboostBinaryCrossEntropyPrim(BinaryCrossEntropyPrim_):
117
140
  def __call__(self, input, target, weight, reduction):
118
141
  converted_reduction = str_to_enum('binary_cross_entropy', 'reduction', reduction)
119
- return _convert_stub(super().__call__(input, target, weight, converted_reduction))
142
+ return super().__call__([input, target, weight, converted_reduction])
120
143
 
121
144
 
122
145
  binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()
@@ -125,7 +148,7 @@ binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()
125
148
  class _PyboostBCEWithLogitsLossPrim(BCEWithLogitsLossPrim_):
126
149
  def __call__(self, input, target, weight, posWeight, reduction):
127
150
  converted_reduction = str_to_enum('binary_cross_entropy_with_logits', 'reduction', reduction)
128
- return _convert_stub(super().__call__(input, target, weight, posWeight, converted_reduction))
151
+ return super().__call__([input, target, weight, posWeight, converted_reduction])
129
152
 
130
153
 
131
154
  binary_cross_entropy_with_logits_impl = _PyboostBCEWithLogitsLossPrim()
@@ -134,7 +157,7 @@ binary_cross_entropy_with_logits_impl = _PyboostBCEWithLogitsLossPrim()
134
157
  class _PyboostBroadcastToPrim(BroadcastToPrim_):
135
158
  def __call__(self, input, shape):
136
159
 
137
- return _convert_stub(super().__call__(input, shape))
160
+ return super().__call__([input, shape])
138
161
 
139
162
 
140
163
  broadcast_to_impl = _PyboostBroadcastToPrim()
@@ -143,40 +166,16 @@ broadcast_to_impl = _PyboostBroadcastToPrim()
143
166
  class _PyboostConcatPrim(ConcatPrim_):
144
167
  def __call__(self, tensors, axis):
145
168
 
146
- return _convert_stub(super().__call__(tensors, axis))
169
+ return super().__call__([tensors, axis])
147
170
 
148
171
 
149
172
  concat_impl = _PyboostConcatPrim()
150
173
 
151
174
 
152
- class _PyboostConvolutionGradPrim(ConvolutionGradPrim_):
153
- def __call__(self, dout, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, output_mask):
154
- converted_stride = to_strides('convolution_grad', 'stride', stride)
155
- converted_padding = to_2d_paddings('convolution_grad', 'padding', padding)
156
- converted_dilation = to_dilations('convolution_grad', 'dilation', dilation)
157
- converted_output_padding = to_output_padding('convolution_grad', 'output_padding', output_padding)
158
- return _convert_stub(super().__call__(dout, input, weight, bias, converted_stride, converted_padding, converted_dilation, transposed, converted_output_padding, groups, output_mask))
159
-
160
-
161
- convolution_grad_impl = _PyboostConvolutionGradPrim()
162
-
163
-
164
- class _PyboostConvolutionPrim(ConvolutionPrim_):
165
- def __call__(self, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
166
- converted_stride = to_strides('convolution', 'stride', stride)
167
- converted_padding = to_2d_paddings('convolution', 'padding', padding)
168
- converted_dilation = to_dilations('convolution', 'dilation', dilation)
169
- converted_output_padding = to_output_padding('convolution', 'output_padding', output_padding)
170
- return _convert_stub(super().__call__(input, weight, bias, converted_stride, converted_padding, converted_dilation, transposed, converted_output_padding, groups))
171
-
172
-
173
- convolution_impl = _PyboostConvolutionPrim()
174
-
175
-
176
175
  class _PyboostCrossPrim(CrossPrim_):
177
176
  def __call__(self, input, other, dim):
178
177
 
179
- return _convert_stub(super().__call__(input, other, dim))
178
+ return super().__call__([input, other, dim])
180
179
 
181
180
 
182
181
  cross_impl = _PyboostCrossPrim()
@@ -185,7 +184,7 @@ cross_impl = _PyboostCrossPrim()
185
184
  class _PyboostCummaxPrim(CummaxPrim_):
186
185
  def __call__(self, input, axis):
187
186
 
188
- return _convert_stub(super().__call__(input, axis))
187
+ return super().__call__([input, axis])
189
188
 
190
189
 
191
190
  cummax_impl = _PyboostCummaxPrim()
@@ -194,7 +193,7 @@ cummax_impl = _PyboostCummaxPrim()
194
193
  class _PyboostEluExtPrim(EluExtPrim_):
195
194
  def __call__(self, input, alpha):
196
195
 
197
- return _convert_stub(super().__call__(input, alpha))
196
+ return super().__call__([input, alpha])
198
197
 
199
198
 
200
199
  elu_ext_impl = _PyboostEluExtPrim()
@@ -203,7 +202,7 @@ elu_ext_impl = _PyboostEluExtPrim()
203
202
  class _PyboostFFNExtPrim(FFNExtPrim_):
204
203
  def __call__(self, x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, activation, inner_precise):
205
204
  converted_activation = str_to_enum('ffn_ext', 'activation', activation)
206
- return _convert_stub(super().__call__(x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, converted_activation, inner_precise))
205
+ return super().__call__([x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, converted_activation, inner_precise])
207
206
 
208
207
 
209
208
  ffn_ext_impl = _PyboostFFNExtPrim()
@@ -212,7 +211,7 @@ ffn_ext_impl = _PyboostFFNExtPrim()
212
211
  class _PyboostFlashAttentionScoreGradPrim(FlashAttentionScoreGradPrim_):
213
212
  def __call__(self, query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode):
214
213
  converted_input_layout = str_to_enum('flash_attention_score_grad', 'input_layout', input_layout)
215
- return _convert_stub(super().__call__(query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode))
214
+ return super().__call__([query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode])
216
215
 
217
216
 
218
217
  flash_attention_score_grad_impl = _PyboostFlashAttentionScoreGradPrim()
@@ -221,17 +220,35 @@ flash_attention_score_grad_impl = _PyboostFlashAttentionScoreGradPrim()
221
220
  class _PyboostFlashAttentionScorePrim(FlashAttentionScorePrim_):
222
221
  def __call__(self, query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode):
223
222
  converted_input_layout = str_to_enum('flash_attention_score', 'input_layout', input_layout)
224
- return _convert_stub(super().__call__(query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode))
223
+ return super().__call__([query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode])
225
224
 
226
225
 
227
226
  flash_attention_score_impl = _PyboostFlashAttentionScorePrim()
228
227
 
229
228
 
229
+ class _PyboostGluGradPrim(GluGradPrim_):
230
+ def __call__(self, grads, x, axis):
231
+
232
+ return super().__call__([grads, x, axis])
233
+
234
+
235
+ glu_grad_impl = _PyboostGluGradPrim()
236
+
237
+
238
+ class _PyboostGLUPrim(GLUPrim_):
239
+ def __call__(self, x, axis):
240
+
241
+ return super().__call__([x, axis])
242
+
243
+
244
+ glu_impl = _PyboostGLUPrim()
245
+
246
+
230
247
  class _PyboostGridSampler2DGradPrim(GridSampler2DGradPrim_):
231
- def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners):
248
+ def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners, output_mask):
232
249
  converted_interpolation_mode = str_to_enum('grid_sampler_2d_grad', 'interpolation_mode', interpolation_mode)
233
250
  converted_padding_mode = str_to_enum('grid_sampler_2d_grad', 'padding_mode', padding_mode)
234
- return _convert_stub(super().__call__(grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners))
251
+ return super().__call__([grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners, output_mask])
235
252
 
236
253
 
237
254
  grid_sampler_2d_grad_impl = _PyboostGridSampler2DGradPrim()
@@ -241,17 +258,17 @@ class _PyboostGridSampler2DPrim(GridSampler2DPrim_):
241
258
  def __call__(self, input_x, grid, interpolation_mode, padding_mode, align_corners):
242
259
  converted_interpolation_mode = str_to_enum('grid_sampler_2d', 'interpolation_mode', interpolation_mode)
243
260
  converted_padding_mode = str_to_enum('grid_sampler_2d', 'padding_mode', padding_mode)
244
- return _convert_stub(super().__call__(input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners))
261
+ return super().__call__([input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners])
245
262
 
246
263
 
247
264
  grid_sampler_2d_impl = _PyboostGridSampler2DPrim()
248
265
 
249
266
 
250
267
  class _PyboostGridSampler3DGradPrim(GridSampler3DGradPrim_):
251
- def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners):
268
+ def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners, output_mask):
252
269
  converted_interpolation_mode = str_to_enum('grid_sampler_3d_grad', 'interpolation_mode', interpolation_mode)
253
270
  converted_padding_mode = str_to_enum('grid_sampler_3d_grad', 'padding_mode', padding_mode)
254
- return _convert_stub(super().__call__(grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners))
271
+ return super().__call__([grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners, output_mask])
255
272
 
256
273
 
257
274
  grid_sampler_3d_grad_impl = _PyboostGridSampler3DGradPrim()
@@ -261,7 +278,7 @@ class _PyboostGridSampler3DPrim(GridSampler3DPrim_):
261
278
  def __call__(self, input_x, grid, interpolation_mode, padding_mode, align_corners):
262
279
  converted_interpolation_mode = str_to_enum('grid_sampler_3d', 'interpolation_mode', interpolation_mode)
263
280
  converted_padding_mode = str_to_enum('grid_sampler_3d', 'padding_mode', padding_mode)
264
- return _convert_stub(super().__call__(input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners))
281
+ return super().__call__([input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners])
265
282
 
266
283
 
267
284
  grid_sampler_3d_impl = _PyboostGridSampler3DPrim()
@@ -270,7 +287,7 @@ grid_sampler_3d_impl = _PyboostGridSampler3DPrim()
270
287
  class _PyboostHShrinkGradPrim(HShrinkGradPrim_):
271
288
  def __call__(self, gradients, features, lambd):
272
289
 
273
- return _convert_stub(super().__call__(gradients, features, lambd))
290
+ return super().__call__([gradients, features, lambd])
274
291
 
275
292
 
276
293
  hshrink_grad_impl = _PyboostHShrinkGradPrim()
@@ -279,7 +296,7 @@ hshrink_grad_impl = _PyboostHShrinkGradPrim()
279
296
  class _PyboostHShrinkPrim(HShrinkPrim_):
280
297
  def __call__(self, input, lambd):
281
298
 
282
- return _convert_stub(super().__call__(input, lambd))
299
+ return super().__call__([input, lambd])
283
300
 
284
301
 
285
302
  hshrink_impl = _PyboostHShrinkPrim()
@@ -288,7 +305,7 @@ hshrink_impl = _PyboostHShrinkPrim()
288
305
  class _PyboostIncreFlashAttentionPrim(IncreFlashAttentionPrim_):
289
306
  def __call__(self, query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, input_layout, scale_value, num_key_value_heads, block_size, inner_precise):
290
307
  converted_input_layout = str_to_enum('incre_flash_attention', 'input_layout', input_layout)
291
- return _convert_stub(super().__call__(query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, converted_input_layout, scale_value, num_key_value_heads, block_size, inner_precise))
308
+ return super().__call__([query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, converted_input_layout, scale_value, num_key_value_heads, block_size, inner_precise])
292
309
 
293
310
 
294
311
  incre_flash_attention_impl = _PyboostIncreFlashAttentionPrim()
@@ -297,7 +314,7 @@ incre_flash_attention_impl = _PyboostIncreFlashAttentionPrim()
297
314
  class _PyboostIsClosePrim(IsClosePrim_):
298
315
  def __call__(self, input, other, rtol, atol, equal_nan):
299
316
 
300
- return _convert_stub(super().__call__(input, other, rtol, atol, equal_nan))
317
+ return super().__call__([input, other, rtol, atol, equal_nan])
301
318
 
302
319
 
303
320
  isclose_impl = _PyboostIsClosePrim()
@@ -306,7 +323,7 @@ isclose_impl = _PyboostIsClosePrim()
306
323
  class _PyboostLogSoftmaxGradPrim(LogSoftmaxGradPrim_):
307
324
  def __call__(self, logits, grad, axis):
308
325
 
309
- return _convert_stub(super().__call__(logits, grad, axis))
326
+ return super().__call__([logits, grad, axis])
310
327
 
311
328
 
312
329
  log_softmax_grad_impl = _PyboostLogSoftmaxGradPrim()
@@ -315,7 +332,7 @@ log_softmax_grad_impl = _PyboostLogSoftmaxGradPrim()
315
332
  class _PyboostLogSoftmaxPrim(LogSoftmaxPrim_):
316
333
  def __call__(self, logits, axis):
317
334
 
318
- return _convert_stub(super().__call__(logits, axis))
335
+ return super().__call__([logits, axis])
319
336
 
320
337
 
321
338
  log_softmax_impl = _PyboostLogSoftmaxPrim()
@@ -324,7 +341,7 @@ log_softmax_impl = _PyboostLogSoftmaxPrim()
324
341
  class _PyboostMatMulPrim(MatMulPrim_):
325
342
  def __call__(self, input, mat2, transpose_a, transpose_b):
326
343
 
327
- return _convert_stub(super().__call__(input, mat2, transpose_a, transpose_b))
344
+ return super().__call__([input, mat2, transpose_a, transpose_b])
328
345
 
329
346
 
330
347
  matmul_impl = _PyboostMatMulPrim()
@@ -336,7 +353,7 @@ class _PyboostMaxPoolGradWithIndicesPrim(MaxPoolGradWithIndicesPrim_):
336
353
  converted_strides = to_strides('max_pool_grad_with_indices', 'strides', strides)
337
354
  converted_pads = to_output_padding('max_pool_grad_with_indices', 'pads', pads)
338
355
  converted_dilation = to_dilations('max_pool_grad_with_indices', 'dilation', dilation)
339
- return _convert_stub(super().__call__(x, grad, argmax, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type))
356
+ return super().__call__([x, grad, argmax, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
340
357
 
341
358
 
342
359
  max_pool_grad_with_indices_impl = _PyboostMaxPoolGradWithIndicesPrim()
@@ -348,7 +365,7 @@ class _PyboostMaxPoolGradWithMaskPrim(MaxPoolGradWithMaskPrim_):
348
365
  converted_strides = to_strides('max_pool_grad_with_mask', 'strides', strides)
349
366
  converted_pads = to_output_padding('max_pool_grad_with_mask', 'pads', pads)
350
367
  converted_dilation = to_dilations('max_pool_grad_with_mask', 'dilation', dilation)
351
- return _convert_stub(super().__call__(x, grad, mask, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type))
368
+ return super().__call__([x, grad, mask, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
352
369
 
353
370
 
354
371
  max_pool_grad_with_mask_impl = _PyboostMaxPoolGradWithMaskPrim()
@@ -360,7 +377,7 @@ class _PyboostMaxPoolWithIndicesPrim(MaxPoolWithIndicesPrim_):
360
377
  converted_strides = to_strides('max_pool_with_indices', 'strides', strides)
361
378
  converted_pads = to_output_padding('max_pool_with_indices', 'pads', pads)
362
379
  converted_dilation = to_dilations('max_pool_with_indices', 'dilation', dilation)
363
- return _convert_stub(super().__call__(x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type))
380
+ return super().__call__([x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
364
381
 
365
382
 
366
383
  max_pool_with_indices_impl = _PyboostMaxPoolWithIndicesPrim()
@@ -372,34 +389,70 @@ class _PyboostMaxPoolWithMaskPrim(MaxPoolWithMaskPrim_):
372
389
  converted_strides = to_strides('max_pool_with_mask', 'strides', strides)
373
390
  converted_pads = to_output_padding('max_pool_with_mask', 'pads', pads)
374
391
  converted_dilation = to_dilations('max_pool_with_mask', 'dilation', dilation)
375
- return _convert_stub(super().__call__(x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type))
392
+ return super().__call__([x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
376
393
 
377
394
 
378
395
  max_pool_with_mask_impl = _PyboostMaxPoolWithMaskPrim()
379
396
 
380
397
 
398
+ class _PyboostMeshgridPrim(MeshgridPrim_):
399
+ def __call__(self, inputs, indexing):
400
+ converted_indexing = str_to_enum('meshgrid', 'indexing', indexing)
401
+ return super().__call__([inputs, converted_indexing])
402
+
403
+
404
+ meshgrid_impl = _PyboostMeshgridPrim()
405
+
406
+
381
407
  class _PyboostNanToNumPrim(NanToNumPrim_):
382
408
  def __call__(self, input, nan, posinf, neginf):
383
409
 
384
- return _convert_stub(super().__call__(input, nan, posinf, neginf))
410
+ return super().__call__([input, nan, posinf, neginf])
385
411
 
386
412
 
387
413
  nan_to_num_impl = _PyboostNanToNumPrim()
388
414
 
389
415
 
416
+ class _PyboostNLLLossGradPrim(NLLLossGradPrim_):
417
+ def __call__(self, logits, loss_grad, labels, weight, total_weight, reduction, ignore_index):
418
+ converted_reduction = str_to_enum('nllloss_grad', 'reduction', reduction)
419
+ return super().__call__([logits, loss_grad, labels, weight, total_weight, converted_reduction, ignore_index])
420
+
421
+
422
+ nllloss_grad_impl = _PyboostNLLLossGradPrim()
423
+
424
+
425
+ class _PyboostNLLLossPrim(NLLLossPrim_):
426
+ def __call__(self, logits, labels, weight, reduction, ignore_index):
427
+ converted_reduction = str_to_enum('nllloss', 'reduction', reduction)
428
+ return super().__call__([logits, labels, weight, converted_reduction, ignore_index])
429
+
430
+
431
+ nllloss_impl = _PyboostNLLLossPrim()
432
+
433
+
390
434
  class _PyboostOneHotExtPrim(OneHotExtPrim_):
391
435
  def __call__(self, tensor, num_classes, on_value, off_value, axis):
392
436
 
393
- return _convert_stub(super().__call__(tensor, num_classes, on_value, off_value, axis))
437
+ return super().__call__([tensor, num_classes, on_value, off_value, axis])
394
438
 
395
439
 
396
440
  one_hot_ext_impl = _PyboostOneHotExtPrim()
397
441
 
398
442
 
443
+ class _PyboostPromptFlashAttentionPrim(PromptFlashAttentionPrim_):
444
+ def __call__(self, query, key, value, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, pse_shift, deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads, sparse_mode, inner_precise):
445
+ converted_input_layout = str_to_enum('prompt_flash_attention', 'input_layout', input_layout)
446
+ return super().__call__([query, key, value, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, pse_shift, deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, num_heads, scale_value, pre_tokens, next_tokens, converted_input_layout, num_key_value_heads, sparse_mode, inner_precise])
447
+
448
+
449
+ prompt_flash_attention_impl = _PyboostPromptFlashAttentionPrim()
450
+
451
+
399
452
  class _PyboostReduceAllPrim(ReduceAllPrim_):
400
453
  def __call__(self, input, axis, keep_dims):
401
454
 
402
- return _convert_stub(super().__call__(input, axis, keep_dims))
455
+ return super().__call__([input, axis, keep_dims])
403
456
 
404
457
 
405
458
  reduce_all_impl = _PyboostReduceAllPrim()
@@ -408,16 +461,34 @@ reduce_all_impl = _PyboostReduceAllPrim()
408
461
  class _PyboostReduceAnyPrim(ReduceAnyPrim_):
409
462
  def __call__(self, x, axis, keep_dims):
410
463
 
411
- return _convert_stub(super().__call__(x, axis, keep_dims))
464
+ return super().__call__([x, axis, keep_dims])
412
465
 
413
466
 
414
467
  reduce_any_impl = _PyboostReduceAnyPrim()
415
468
 
416
469
 
470
+ class _PyboostReduceMaxPrim(ReduceMaxPrim_):
471
+ def __call__(self, x, axis, keep_dims):
472
+
473
+ return super().__call__([x, axis, keep_dims])
474
+
475
+
476
+ reduce_max_impl = _PyboostReduceMaxPrim()
477
+
478
+
479
+ class _PyboostReduceMinPrim(ReduceMinPrim_):
480
+ def __call__(self, x, axis, keep_dims):
481
+
482
+ return super().__call__([x, axis, keep_dims])
483
+
484
+
485
+ reduce_min_impl = _PyboostReduceMinPrim()
486
+
487
+
417
488
  class _PyboostReverseV2Prim(ReverseV2Prim_):
418
489
  def __call__(self, input, axis):
419
490
 
420
- return _convert_stub(super().__call__(input, axis))
491
+ return super().__call__([input, axis])
421
492
 
422
493
 
423
494
  reverse_v2_impl = _PyboostReverseV2Prim()
@@ -426,16 +497,16 @@ reverse_v2_impl = _PyboostReverseV2Prim()
426
497
  class _PyboostRmsNormPrim(RmsNormPrim_):
427
498
  def __call__(self, x, gamma, epsilon):
428
499
 
429
- return _convert_stub(super().__call__(x, gamma, epsilon))
500
+ return super().__call__([x, gamma, epsilon])
430
501
 
431
502
 
432
503
  rms_norm_impl = _PyboostRmsNormPrim()
433
504
 
434
505
 
435
506
  class _PyboostRollPrim(RollPrim_):
436
- def __call__(self, input, shift, axis):
507
+ def __call__(self, input, shifts, dims):
437
508
 
438
- return _convert_stub(super().__call__(input, shift, axis))
509
+ return super().__call__([input, shifts, dims])
439
510
 
440
511
 
441
512
  roll_impl = _PyboostRollPrim()
@@ -444,16 +515,34 @@ roll_impl = _PyboostRollPrim()
444
515
  class _PyboostSearchSortedPrim(SearchSortedPrim_):
445
516
  def __call__(self, sorted_sequence, values, sorter, dtype, right):
446
517
 
447
- return _convert_stub(super().__call__(sorted_sequence, values, sorter, dtype, right))
518
+ return super().__call__([sorted_sequence, values, sorter, dtype, right])
448
519
 
449
520
 
450
521
  searchsorted_impl = _PyboostSearchSortedPrim()
451
522
 
452
523
 
524
+ class _PyboostSmoothL1LossGradPrim(SmoothL1LossGradPrim_):
525
+ def __call__(self, prediction, target, dout, beta, reduction):
526
+ converted_reduction = str_to_enum('smooth_l1_loss_grad', 'reduction', reduction)
527
+ return super().__call__([prediction, target, dout, beta, converted_reduction])
528
+
529
+
530
+ smooth_l1_loss_grad_impl = _PyboostSmoothL1LossGradPrim()
531
+
532
+
533
+ class _PyboostSmoothL1LossPrim(SmoothL1LossPrim_):
534
+ def __call__(self, prediction, target, beta, reduction):
535
+ converted_reduction = str_to_enum('smooth_l1_loss', 'reduction', reduction)
536
+ return super().__call__([prediction, target, beta, converted_reduction])
537
+
538
+
539
+ smooth_l1_loss_impl = _PyboostSmoothL1LossPrim()
540
+
541
+
453
542
  class _PyboostSoftmaxPrim(SoftmaxPrim_):
454
543
  def __call__(self, input, axis):
455
544
 
456
- return _convert_stub(super().__call__(input, axis))
545
+ return super().__call__([input, axis])
457
546
 
458
547
 
459
548
  softmax_impl = _PyboostSoftmaxPrim()
@@ -462,7 +551,7 @@ softmax_impl = _PyboostSoftmaxPrim()
462
551
  class _PyboostSoftShrinkGradPrim(SoftShrinkGradPrim_):
463
552
  def __call__(self, input_grad, input_x, lambd):
464
553
 
465
- return _convert_stub(super().__call__(input_grad, input_x, lambd))
554
+ return super().__call__([input_grad, input_x, lambd])
466
555
 
467
556
 
468
557
  softshrink_grad_impl = _PyboostSoftShrinkGradPrim()
@@ -471,43 +560,79 @@ softshrink_grad_impl = _PyboostSoftShrinkGradPrim()
471
560
  class _PyboostSoftShrinkPrim(SoftShrinkPrim_):
472
561
  def __call__(self, input, lambd):
473
562
 
474
- return _convert_stub(super().__call__(input, lambd))
563
+ return super().__call__([input, lambd])
475
564
 
476
565
 
477
566
  softshrink_impl = _PyboostSoftShrinkPrim()
478
567
 
479
568
 
569
+ class _PyboostSoftMarginLossGradPrim(SoftMarginLossGradPrim_):
570
+ def __call__(self, predict, label, dout, reduction):
571
+ converted_reduction = str_to_enum('soft_margin_loss_grad', 'reduction', reduction)
572
+ return super().__call__([predict, label, dout, converted_reduction])
573
+
574
+
575
+ soft_margin_loss_grad_impl = _PyboostSoftMarginLossGradPrim()
576
+
577
+
578
+ class _PyboostSoftMarginLossPrim(SoftMarginLossPrim_):
579
+ def __call__(self, input, target, reduction):
580
+ converted_reduction = str_to_enum('soft_margin_loss', 'reduction', reduction)
581
+ return super().__call__([input, target, converted_reduction])
582
+
583
+
584
+ soft_margin_loss_impl = _PyboostSoftMarginLossPrim()
585
+
586
+
587
+ class _PyboostSplitPrim(SplitPrim_):
588
+ def __call__(self, input_x, axis, output_num):
589
+
590
+ return super().__call__([input_x, axis, output_num])
591
+
592
+
593
+ split_impl = _PyboostSplitPrim()
594
+
595
+
596
+ class _PyboostSqueezePrim(SqueezePrim_):
597
+ def __call__(self, input, axis):
598
+
599
+ return super().__call__([input, axis])
600
+
601
+
602
+ squeeze_impl = _PyboostSqueezePrim()
603
+
604
+
480
605
  class _PyboostStackExtPrim(StackExtPrim_):
481
606
  def __call__(self, tensors, dim):
482
607
 
483
- return _convert_stub(super().__call__(tensors, dim))
608
+ return super().__call__([tensors, dim])
484
609
 
485
610
 
486
611
  stack_ext_impl = _PyboostStackExtPrim()
487
612
 
488
613
 
489
- class _PyboostTrilExtPrim(TrilExtPrim_):
614
+ class _PyboostTriuPrim(TriuPrim_):
490
615
  def __call__(self, input, diagonal):
491
616
 
492
- return _convert_stub(super().__call__(input, diagonal))
617
+ return super().__call__([input, diagonal])
493
618
 
494
619
 
495
- tril_ext_impl = _PyboostTrilExtPrim()
620
+ triu_impl = _PyboostTriuPrim()
496
621
 
497
622
 
498
- class _PyboostTriuPrim(TriuPrim_):
499
- def __call__(self, input, diagonal):
623
+ class _PyboostUniqueConsecutivePrim(UniqueConsecutivePrim_):
624
+ def __call__(self, input, return_inverse, return_counts, dim):
500
625
 
501
- return _convert_stub(super().__call__(input, diagonal))
626
+ return super().__call__([input, return_inverse, return_counts, dim])
502
627
 
503
628
 
504
- triu_impl = _PyboostTriuPrim()
629
+ unique_consecutive_impl = _PyboostUniqueConsecutivePrim()
505
630
 
506
631
 
507
632
  class _PyboostUpsampleTrilinear3DGradPrim(UpsampleTrilinear3DGradPrim_):
508
633
  def __call__(self, dy, input_size, output_size, scales, align_corners):
509
634
 
510
- return _convert_stub(super().__call__(dy, input_size, output_size, scales, align_corners))
635
+ return super().__call__([dy, input_size, output_size, scales, align_corners])
511
636
 
512
637
 
513
638
  upsample_trilinear3d_grad_impl = _PyboostUpsampleTrilinear3DGradPrim()
@@ -516,16 +641,25 @@ upsample_trilinear3d_grad_impl = _PyboostUpsampleTrilinear3DGradPrim()
516
641
  class _PyboostUpsampleTrilinear3DPrim(UpsampleTrilinear3DPrim_):
517
642
  def __call__(self, x, output_size, scales, align_corners):
518
643
 
519
- return _convert_stub(super().__call__(x, output_size, scales, align_corners))
644
+ return super().__call__([x, output_size, scales, align_corners])
520
645
 
521
646
 
522
647
  upsample_trilinear3d_impl = _PyboostUpsampleTrilinear3DPrim()
523
648
 
524
649
 
650
+ class _PyboostFusedInferAttentionScorePrim(FusedInferAttentionScorePrim_):
651
+ def __call__(self, query, key, value, pse_shift, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode):
652
+ converted_input_layout = str_to_enum('fused_infer_attention_score', 'input_layout', input_layout)
653
+ return super().__call__([query, key, value, pse_shift, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, num_heads, scale_value, pre_tokens, next_tokens, converted_input_layout, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode])
654
+
655
+
656
+ fused_infer_attention_score_impl = _PyboostFusedInferAttentionScorePrim()
657
+
658
+
525
659
  class _PyboostGroupedMatmulPrim(GroupedMatmulPrim_):
526
- def __call__(self, x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type):
660
+ def __call__(self, x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type, transpose_a, transpose_b):
527
661
 
528
- return _convert_stub(super().__call__(x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type))
662
+ return super().__call__([x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type, transpose_a, transpose_b])
529
663
 
530
664
 
531
665
  grouped_matmul_impl = _PyboostGroupedMatmulPrim()
@@ -534,7 +668,7 @@ grouped_matmul_impl = _PyboostGroupedMatmulPrim()
534
668
  class _PyboostQuantBatchMatmulPrim(QuantBatchMatmulPrim_):
535
669
  def __call__(self, x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype):
536
670
 
537
- return _convert_stub(super().__call__(x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype))
671
+ return super().__call__([x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype])
538
672
 
539
673
 
540
674
  quant_batch_matmul_impl = _PyboostQuantBatchMatmulPrim()
@@ -543,7 +677,7 @@ quant_batch_matmul_impl = _PyboostQuantBatchMatmulPrim()
543
677
  class _PyboostWeightQuantBatchMatmulPrim(WeightQuantBatchMatmulPrim_):
544
678
  def __call__(self, x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, transpose_x, transpose_weight, antiquant_group_size):
545
679
 
546
- return _convert_stub(super().__call__(x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, transpose_x, transpose_weight, antiquant_group_size))
680
+ return super().__call__([x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, transpose_x, transpose_weight, antiquant_group_size])
547
681
 
548
682
 
549
683
  weight_quant_batch_matmul_impl = _PyboostWeightQuantBatchMatmulPrim()