mindspore 2.7.0rc1-cp310-cp310-win_amd64.whl → 2.7.1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (370)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -250,11 +250,11 @@ def add_ext(input, other, alpha=1):
250
250
  input (Union[Tensor, number.Number, bool]): The first input is a number.Number or
251
251
  a bool or a tensor whose data type is
252
252
  `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
253
- `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
253
+ `bool <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
254
254
  other (Union[Tensor, number.Number, bool]): The second input, is a number.Number or
255
255
  a bool or a tensor whose data type is
256
256
  `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
257
- `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
257
+ `bool <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
258
258
  alpha (number.Number): A scaling factor applied to `other`, default 1.
259
259
 
260
260
  Returns:
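For intuition about the `alpha` argument documented in the add_ext hunk above: it is described as a scaling factor applied to `other`, i.e. the result equals input + alpha * other. A minimal NumPy sketch of that arithmetic (illustration only, not the operator's implementation):

import numpy as np

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
y = np.array([10.0, 10.0, 10.0], dtype=np.float32)
alpha = 0.5

# add_ext(input, other, alpha) adds `other` scaled by `alpha`
print(x + alpha * y)  # [6. 7. 8.]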
@@ -312,7 +312,7 @@ def add(input, other):
312
312
 
313
313
  Note:
314
314
  - The two inputs can not be bool type at the same time,
315
- [True, Tensor(True, bool\_), Tensor(np.array([True]), bool\_)] are all considered bool type.
315
+ [True, Tensor(True), Tensor(np.array([True]))] are all considered bool type.
316
316
  - Support broadcast, support implicit type conversion and type promotion.
317
317
  - When the input is a tensor, the dimension should be greater than or equal to 1.
318
318
 
@@ -1558,9 +1558,6 @@ def clone(input):
1558
1558
  r"""
1559
1559
  Returns a copy of the input tensor.
1560
1560
 
1561
- .. warning::
1562
- This is an experimental API that is subject to change or deletion.
1563
-
1564
1561
  Note:
1565
1562
  This function is differentiable, and gradients will flow back directly from the calculation
1566
1563
  result of the function to the `input`.
@@ -1939,9 +1936,6 @@ def count_nonzero(input, dim=None):
1939
1936
  r"""
1940
1937
  Count the number of non-zero elements in the Tensor `input` on a given dimension `dim`. If no dim is specified then all non-zeros in the tensor are counted.
1941
1938
 
1942
- .. warning::
1943
- This is an experimental API that is subject to change or deletion.
1944
-
1945
1939
  Args:
1946
1940
  input (Tensor): Input data is used to count non-zero numbers. With shape
1947
1941
  :math:`(*)` where :math:`*` means, any number of additional dimensions.
@@ -1985,6 +1979,112 @@ def count_nonzero(input, dim=None):
1985
1979
  return count_nonzero_op(input, dim)
1986
1980
 
1987
1981
 
1982
+ def cross_entropy_loss_grad(grad_loss, log_prob, target, weight=None, grad_zloss=None, lse_for_zloss=None, reduction='mean', ignore_index=-100, label_smoothing=0.0, lse_square_scale_for_zloss=0.0):
1983
+ r"""
1984
+
1985
+ """
1986
+ return cross_entropy_loss_grad_op(grad_loss, log_prob, target, weight, grad_zloss, lse_for_zloss, reduction, ignore_index, label_smoothing, lse_square_scale_for_zloss)
1987
+
1988
+
1989
+ def cross_entropy_loss(input, target, weight=None, reduction='mean', ignore_index=-100, label_smoothing=0.0, lse_square_scale_for_zloss=0.0, return_zloss=False):
1990
+ r"""
1991
+ Computes the cross entropy loss between input and target.
1992
+
1993
+ Assume the number of classes :math:`C` in the range :math:`[0, C)`,
1994
+ the loss with reduction=none can be described as:
1995
+
1996
+ .. math::
1997
+
1998
+ \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
1999
+ l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
2000
+ \cdot \mathbb{1}\{y_n \not= \text{ignore_index}\}
2001
+
2002
+ where :math:`x` is the inputs, :math:`y` is the target, :math:`w` is the weight, :math:`N` is the batch size,
2003
+ :math:`c` belonging to :math:`[0, C-1]` is class index, where :math:`C` is the number of classes.
2004
+
2005
+ If `reduction` is not ``None`` (default ``'mean'`` ), then
2006
+
2007
+ .. math::
2008
+
2009
+ \ell(x, y) = \begin{cases}
2010
+ \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore_index}\}} l_n, &
2011
+ \text{if reduction} = \text{'mean',}\\
2012
+ \sum_{n=1}^N l_n, &
2013
+ \text{if reduction} = \text{'sum'.}
2014
+ \end{cases}
2015
+
2016
+ .. warning::
2017
+ This is an experimental API that is subject to change or deletion.
2018
+
2019
+ Inputs:
2020
+ - **input** (Tensor) - Tensor of shape :math:`(N, C)` where `C = number of classes`, data type must be bfloat16, float16 or float32.
2021
+ - **target** (Tensor) - For class indices, tensor of shape :math:`(N)`, data type must be int64. The value must be in range [0, C).
2022
+ - **weight** (Tensor, optional) - A rescaling weight applied to the loss of each batch element.
2023
+ If not None, the shape is :math:`(C,)`, data type must be float32. Default: ``None`` .
2024
+ - **reduction** (str, optional) - Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
2025
+ ``'sum'`` . Default: ``'mean'`` .
2026
+
2027
+ - ``'none'``: no reduction will be applied.
2028
+ - ``'mean'``: compute and return the weighted mean of elements in the output.
2029
+ - ``'sum'``: the output elements will be summed.
2030
+
2031
+ - **ignore_index** (int, optional) - Specifies a target value that is ignored and does not contribute to the input
2032
+ gradient. When set to negative values, no target value is ignored. It should be int64.
2033
+ Default: ``-100`` .
2034
+ - **label_smoothing** (float, optional) - Label smoothing values, a regularization tool used to prevent the model
2035
+ from overfitting when calculating Loss. This value must be 0.0 currently. Default: ``0.0`` .
2036
+ - **lse_square_scale_for_zloss** (float, optional) - The value range is [0.0, 1.0), not enabled for now, can only be 0.0. Default: ``0.0`` .
2037
+ - **return_zloss** (bool, optional) - Not enabled for now, can only be ``False``. Default: ``False`` .
2038
+
2039
+ Outputs:
2040
+ A tuple consisting of 4 Tensors.
2041
+
2042
+ - **loss** (Tensor) - loss between `input` and `target`, the dtype is the same as `input`.
2043
+
2044
+ - If `reduction` is ``'none'`` , the shape is :math:`(N,)` .
2045
+ - If `reduction` is ``'sum'`` or ``'mean'``, the shape is :math:`(1,)` .
2046
+
2047
+ - **log_prob** (Tensor) - the shape is :math:`(N, C)` with the same dtype as `input`.
2048
+ - **zloss** (Tensor) - the shape is :math:`(N,)` if `return_zloss` is True, or the shape is :math:`(0,)` with the same dtype as `input`. This parameter is disabled for now.
2049
+ - **lse_for_zloss** (Tensor) - the shape is :math:`(N,)` if `lse_square_scale_for_zloss` is not 0.0, or the shape is :math:`(0,)` with the same dtype as `input`. This parameter is disabled for now.
2050
+
2051
+
2052
+ Raises:
2053
+ ValueError: If `reduction` is not one of ``'none'``, ``'mean'`` or ``'sum'``.
2054
+ TypeError: If `input`, `target` or `weight` is not a Tensor.
2055
+
2056
+ Supported Platforms:
2057
+ ``Ascend``
2058
+
2059
+ Examples:
2060
+ >>> import mindspore
2061
+ >>> import numpy as np
2062
+ >>> from mindspore import Tensor, nn, ops
2063
+ >>>
2064
+ >>>
2065
+ >>> class Net(nn.Cell):
2066
+ ... def __init__(self):
2067
+ ... super(Net, self).__init__()
2068
+ ... self.cross_entropy_loss = ops.auto_generate.CrossEntropyLoss()
2069
+ ...
2070
+ ... def construct(self, input, target, weight):
2071
+ ... result = self.cross_entropy_loss(input, target, weight)
2072
+ ... return result
2073
+ ...
2074
+ >>>
2075
+ >>> net = Net()
2076
+ >>> input = Tensor(np.array([[0.2, 0.7, 0.1], [0.2, 0.7, 0.1]]), mindspore.float32)
2077
+ >>> target = Tensor(np.array([0, 1]), mindspore.int64)
2078
+ >>> weight = Tensor(np.array([1, 0.5, 0.5]), mindspore.float32)
2079
+ >>> output = net(input, target, weight)
2080
+ >>> print(output[:2])
2081
+ (Tensor(shape=[1], dtype=Float32, value= [ 1.10128295e+00]), Tensor(shape=[2, 3], dtype=Float32, value=
2082
+ [[-1.26794958e+00, -7.67949641e-01, -1.36794960e+00],
2083
+ [-1.26794958e+00, -7.67949641e-01, -1.36794960e+00]]))
2084
+ """
2085
+ return cross_entropy_loss_op(input, target, weight, reduction, ignore_index, label_smoothing, lse_square_scale_for_zloss, return_zloss)
2086
+
2087
+
1988
2088
  def cummax(input, axis):
1989
2089
  r"""
1990
2090
  Return the cumulative maximum values and their indices along the given axis of the tensor.
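As a cross-check of the ``'mean'`` reduction formula and the example output in the cross_entropy_loss docstring added above, the following NumPy sketch reproduces the printed loss value (illustrative only; it is not how the fused Ascend kernel computes it):

import numpy as np

logits = np.array([[0.2, 0.7, 0.1], [0.2, 0.7, 0.1]], dtype=np.float32)
target = np.array([0, 1])
weight = np.array([1.0, 0.5, 0.5], dtype=np.float32)

# log-softmax over the class dimension
log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
# per-sample weighted negative log-likelihood: l_n = -w_{y_n} * log_prob[n, y_n]
per_sample = -weight[target] * log_prob[np.arange(len(target)), target]
# 'mean' reduction divides by the sum of the selected weights
loss = per_sample.sum() / weight[target].sum()
print(round(float(loss), 5))  # 1.10128, matching the docstring example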
@@ -2162,6 +2262,13 @@ def dense(input, weight, bias=None):
2162
2262
  return dense_op(input, weight, bias)
2163
2263
 
2164
2264
 
2265
+ def dequant_swiglu_quant(x, weight_scale, activation_scale, bias=None, quant_scale=None, quant_offset=None, group_index=None, activate_left=False, quant_mode='static'):
2266
+ r"""
2267
+
2268
+ """
2269
+ return dequant_swiglu_quant_op(x, weight_scale, activation_scale, bias, quant_scale, quant_offset, group_index, activate_left, quant_mode)
2270
+
2271
+
2165
2272
  def diagonal(input, offset=0, dim1=0, dim2=1):
2166
2273
  r"""
2167
2274
  Returns diagonals of the input tensor along specified dimension.
@@ -2330,9 +2437,6 @@ def dot(input, other):
2330
2437
  r"""
2331
2438
  Computes the dot product of two 1D tensor.
2332
2439
 
2333
- .. warning::
2334
- This is an experimental API that is subject to change or deletion.
2335
-
2336
2440
  Args:
2337
2441
  input (Tensor): The first input in the dot product, must be 1D.
2338
2442
  other (Tensor): The second input in the dot product, must be 1D.
@@ -2467,104 +2571,6 @@ def elu(input_x, alpha=1.0):
2467
2571
  return elu_op(input_x)
2468
2572
 
2469
2573
 
2470
- def embedding_apply_adam_w(var_handle, beta1_power, beta2_power, lr, weight_decay, beta1, beta2, epsilon, grad, keys, max_grad_norm, global_step, embedding_dim, ams_grad=(0,), mask_zero=(0,), padding_key=(0,), padding_key_mask=(1,), completion_key=(0,), completion_key_mask=(1,), _embedding_dim=1, _max_key_num=1):
2471
- r"""
2472
-
2473
- """
2474
- return embedding_apply_adam_w_op(var_handle, beta1_power, beta2_power, lr, weight_decay, beta1, beta2, epsilon, grad, keys, max_grad_norm, global_step, embedding_dim, ams_grad, mask_zero, padding_key, padding_key_mask, completion_key, completion_key_mask, _embedding_dim, _max_key_num)
2475
-
2476
-
2477
- def embedding_apply_adam(var_handle, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, keys, global_step, embedding_dim, mask_zero=(0,), padding_key=(0,), padding_key_mask=(1,), completion_key=(0,), completion_key_mask=(1,), _embedding_dim=1, _max_key_num=1):
2478
- r"""
2479
-
2480
- """
2481
- return embedding_apply_adam_op(var_handle, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, keys, global_step, embedding_dim, mask_zero, padding_key, padding_key_mask, completion_key, completion_key_mask, _embedding_dim, _max_key_num)
2482
-
2483
-
2484
- def embedding_apply_ada_grad(var_handle, lr, grad, keys, global_step, embedding_dim, mask_zero=(0,), padding_key=(0,), padding_key_mask=(1,), completion_key=(0,), completion_key_mask=(1,), _embedding_dim=1, _max_key_num=1):
2485
- r"""
2486
-
2487
- """
2488
- return embedding_apply_ada_grad_op(var_handle, lr, grad, keys, global_step, embedding_dim, mask_zero, padding_key, padding_key_mask, completion_key, completion_key_mask, _embedding_dim, _max_key_num)
2489
-
2490
-
2491
- def embedding_apply_ftrl(var_handle, lr, lr_power, lambda1, lambda2, grad, keys, global_step, embedding_dim, mask_zero=(0,), padding_key=(0,), padding_key_mask=(1,), completion_key=(0,), completion_key_mask=(1,), _embedding_dim=1, _max_key_num=1):
2492
- r"""
2493
-
2494
- """
2495
- return embedding_apply_ftrl_op(var_handle, lr, lr_power, lambda1, lambda2, grad, keys, global_step, embedding_dim, mask_zero, padding_key, padding_key_mask, completion_key, completion_key_mask, _embedding_dim, _max_key_num)
2496
-
2497
-
2498
- def embedding_apply_rmsprop(var_handle, lr, rho, momentum, epsilon, grad, keys, global_step, embedding_dim, mask_zero=(0,), padding_key=(0,), padding_key_mask=(1,), completion_key=(0,), completion_key_mask=(1,), _embedding_dim=1, _max_key_num=1):
2499
- r"""
2500
-
2501
- """
2502
- return embedding_apply_rmsprop_op(var_handle, lr, rho, momentum, epsilon, grad, keys, global_step, embedding_dim, mask_zero, padding_key, padding_key_mask, completion_key, completion_key_mask, _embedding_dim, _max_key_num)
2503
-
2504
-
2505
- def embedding_apply_sgd(var_handle, lr, grad, keys, global_step, embedding_dim, mask_zero=(0,), padding_key=(0,), padding_key_mask=(1,), completion_key=(0,), completion_key_mask=(1,), _embedding_dim=1, _max_key_num=1):
2506
- r"""
2507
-
2508
- """
2509
- return embedding_apply_sgd_op(var_handle, lr, grad, keys, global_step, embedding_dim, mask_zero, padding_key, padding_key_mask, completion_key, completion_key_mask, _embedding_dim, _max_key_num)
2510
-
2511
-
2512
- def embedding_feature_mapping_export(file_path, table_name, global_step, values, embedding_dim, feature_id, offset_id):
2513
- r"""
2514
-
2515
- """
2516
- return embedding_feature_mapping_export_op(file_path, table_name, global_step, values, embedding_dim, feature_id, offset_id)
2517
-
2518
-
2519
- def embedding_feature_mapping_file_size(file_path, table_name, global_step, embedding_dim, only_offset_flag=True):
2520
- r"""
2521
-
2522
- """
2523
- return embedding_feature_mapping_file_size_op(file_path, table_name, global_step, embedding_dim, only_offset_flag)
2524
-
2525
-
2526
- def embedding_feature_mapping_find(table_name, feature_size, num=1):
2527
- r"""
2528
-
2529
- """
2530
- return embedding_feature_mapping_find_op(table_name, feature_size, num)
2531
-
2532
-
2533
- def embedding_feature_mapping_import(file_path, teble_name, feature_size, global_step, embedding_dim, only_offset_flag=True, num=1):
2534
- r"""
2535
-
2536
- """
2537
- return embedding_feature_mapping_import_op(file_path, teble_name, feature_size, global_step, embedding_dim, only_offset_flag, num)
2538
-
2539
-
2540
- def embedding_feature_mapping_insert(table_name, num, feature_id, offset_id):
2541
- r"""
2542
-
2543
- """
2544
- return embedding_feature_mapping_insert_op(table_name, num, feature_id, offset_id)
2545
-
2546
-
2547
- def embedding_feature_mapping_table_size(table_name):
2548
- r"""
2549
-
2550
- """
2551
- return embedding_feature_mapping_table_size_op(table_name)
2552
-
2553
-
2554
- def embedding_feature_mapping_v2(table_name, feature_id, table_total_size, table_actual_size):
2555
- r"""
2556
-
2557
- """
2558
- return embedding_feature_mapping_v2_op(table_name, feature_id, table_total_size, table_actual_size)
2559
-
2560
-
2561
- def embedding_table_evict(var_handle, global_step, steps_to_live=0):
2562
- r"""
2563
-
2564
- """
2565
- return embedding_table_evict_op(var_handle, global_step, steps_to_live)
2566
-
2567
-
2568
2574
  def equal(input, other):
2569
2575
  r"""
2570
2576
  Compute the equivalence of the two inputs element-wise.
@@ -3415,6 +3421,43 @@ def floor(input):
3415
3421
  return floor_op(input)
3416
3422
 
3417
3423
 
3424
+ def format_cast(input, acl_format):
3425
+ r"""
3426
+ Change tensor format.
3427
+
3428
+ .. warning::
3429
+ FormatCast will not work in the GE backend; the original input will be returned.
3430
+
3431
+ Args:
3432
+ input (Tensor): The input tensor.
3433
+ acl_format (int): The enum value of the ACL format. Valid values are listed below:
3434
+ - ``0`` NCHW
3435
+ - ``1`` NHWC
3436
+ - ``2`` ND
3437
+ - ``3`` NC1HWC0
3438
+ - ``4`` FRACTAL_Z
3439
+ - ``27`` NDHWC
3440
+ - ``29`` FRACTAL_NZ
3441
+ - ``30`` NCDHW
3442
+ - ``32`` NDC1HWC0
3443
+ - ``33`` FRACTAL_Z_3D
3444
+
3445
+ Returns:
3446
+ Tensor
3447
+
3448
+ Supported Platforms:
3449
+ ``Ascend``
3450
+
3451
+ Examples:
3452
+ >>> import mindspore
3453
+ >>> input = mindspore.ops.randn((2, 3, 4, 5))
3454
+ >>> output = mindspore.ops.format_cast(input, 2)
3455
+ >>> print(output.shape)
3456
+ (2, 3, 4, 5)
3457
+ """
3458
+ return format_cast_op(input, acl_format)
3459
+
3460
+
3418
3461
  def frac_ext(input):
3419
3462
  r"""
3420
3463
  Calculates the fractional part of each element in the input.
@@ -3526,7 +3569,7 @@ def gather(input_params, input_indices, axis, batch_dims=0):
3526
3569
  - The value of input_indices must be in the range of `[0, input_param.shape[axis])`.
3527
3570
  On CPU and GPU, an error is raised if an out of bound indice is found. On Ascend, the results may be
3528
3571
  undefined.
3529
- - The data type of input_params cannot be `mindspore.bool_` .
3572
+ - The data type of input_params cannot be `mindspore.bool` .
3530
3573
  - The shape of returned tensor is :math:`input\_params.shape[:axis] + input\_indices.shape[batch\_dims:] + input\_params.shape[axis + 1:]` .
3531
3574
 
3532
3575
  Args:
@@ -3910,7 +3953,6 @@ def histc_ext(input, bins=100, min=0, max=0):
3910
3953
  Elements lower than min or higher than max are ignored.
3911
3954
 
3912
3955
  .. warning::
3913
- This is an experimental API that is subject to change or deletion.
3914
3956
  If input is int64, valid values fit within int32; exceeding this may cause precision errors.
3915
3957
 
3916
3958
  Args:
@@ -4622,7 +4664,7 @@ def index(input, indices):
4622
4664
  [2 6 5]
4623
4665
  >>> input2 = Tensor(np.arange(4 * 3 * 3).reshape(4, 3, 3), mindspore.int32)
4624
4666
  >>> indices3 = Tensor(np.array([1, 0]), mindspore.int32)
4625
- >>> indices4 = Tensor(np.array([1, 1, 0]), mindspore.bool_)
4667
+ >>> indices4 = Tensor(np.array([1, 1, 0]), mindspore.bool)
4626
4668
  >>> output2 = ops.auto_generate.index(input2, [indices3, indices4])
4627
4669
  >>> print(output2)
4628
4670
  [[ 9 10 11]
@@ -4698,6 +4740,20 @@ def inplace_add_ext(input, other, alpha=1):
4698
4740
  return inplace_add_ext_op(input, other, alpha)
4699
4741
 
4700
4742
 
4743
+ def inplace_bernoulli_scalar(input, p, seed, offset):
4744
+ r"""
4745
+
4746
+ """
4747
+ return inplace_bernoulli_scalar_op(input, p, seed, offset)
4748
+
4749
+
4750
+ def inplace_bernoulli_tensor(input, p, seed, offset):
4751
+ r"""
4752
+
4753
+ """
4754
+ return inplace_bernoulli_tensor_op(input, p, seed, offset)
4755
+
4756
+
4701
4757
  def inplace_clamp_scalar(input, min=None, max=None):
4702
4758
  r"""
4703
4759
 
@@ -4712,11 +4768,11 @@ def inplace_clamp_tensor(input, min=None, max=None):
4712
4768
  return inplace_clamp_tensor_op(input, min, max)
4713
4769
 
4714
4770
 
4715
- def inplace_copy(input, src):
4771
+ def inplace_copy(input, src, non_blocking=False):
4716
4772
  r"""
4717
4773
 
4718
4774
  """
4719
- return inplace_copy_op(input, src)
4775
+ return inplace_copy_op(input, src, non_blocking)
4720
4776
 
4721
4777
 
4722
4778
  def divmod_scalar_(input, other, rounding_mode=None):
@@ -5064,6 +5120,25 @@ def inplace_scatter_add(input, dim, index, src):
5064
5120
  return inplace_scatter_add_op(input, dim, index, src)
5065
5121
 
5066
5122
 
5123
+ def inplace_sigmoid(input):
5124
+ r"""
5125
+ sigmoid_() -> Tensor
5126
+
5127
+ In-place version of sigmoid().
5128
+
5129
+ .. warning::
5130
+ Only supports Ascend.
5131
+ """
5132
+ return inplace_sigmoid_op(input)
5133
+
5134
+
5135
+ def inplace_sign(input):
5136
+ r"""
5137
+
5138
+ """
5139
+ return inplace_sign_op(input)
5140
+
5141
+
5067
5142
  def inplace_silu(input):
5068
5143
  r"""
5069
5144
  Computes Sigmoid Linear Unit of input element-wise. The SiLU function is defined as:
@@ -5380,7 +5455,7 @@ def isinf(input):
5380
5455
  Return a boolean tensor indicating which elements are +/- infinity.
5381
5456
 
5382
5457
  .. warning::
5383
- - This is an experimental API that is subject to change.
5458
+ - This is an experimental API that is subject to change or deletion.
5384
5459
  - For Ascend, it is only supported on platforms above Atlas A2.
5385
5460
 
5386
5461
  Args:
@@ -5491,6 +5566,13 @@ def kthvalue(input, k, dim=-1, keepdim=False):
5491
5566
  return kthvalue_op(input, k, dim, keepdim)
5492
5567
 
5493
5568
 
5569
+ def kv_scale_cache(key_scale, value_scale, key_value_scale_cache, batch_valid_length, cache_mode):
5570
+ r"""
5571
+
5572
+ """
5573
+ return kv_scale_cache_op(key_scale, value_scale, key_value_scale_cache, batch_valid_length, cache_mode)
5574
+
5575
+
5494
5576
  def l1_loss_ext(input, target, reduction='mean'):
5495
5577
  r"""
5496
5578
  Calculate the mean absolute error between the `input` value and the `target` value.
@@ -6146,7 +6228,7 @@ def masked_fill(input_x, mask, value):
6146
6228
  Examples:
6147
6229
  >>> import mindspore
6148
6230
  >>> input_x = mindspore.tensor([1., 2., 3., 4.], mindspore.float32)
6149
- >>> mask = mindspore.tensor([True, True, False, True], mindspore.bool_)
6231
+ >>> mask = mindspore.tensor([True, True, False, True], mindspore.bool)
6150
6232
  >>> output = mindspore.ops.masked_fill(input_x, mask, 0.5)
6151
6233
  >>> print(output)
6152
6234
  [0.5 0.5 3. 0.5]
@@ -6154,6 +6236,13 @@ def masked_fill(input_x, mask, value):
6154
6236
  return masked_fill_op(input_x, mask, value)
6155
6237
 
6156
6238
 
6239
+ def masked_scatter(input, mask, source):
6240
+ r"""
6241
+
6242
+ """
6243
+ return masked_scatter_op(input, mask, source)
6244
+
6245
+
6157
6246
  def masked_select(input, mask):
6158
6247
  r"""
6159
6248
  Return a new 1-D tensor which indexes the `input` tensor according to the boolean `mask`.
@@ -6173,7 +6262,7 @@ def masked_select(input, mask):
6173
6262
  Examples:
6174
6263
  >>> import mindspore
6175
6264
  >>> x = mindspore.tensor([1, 2, 3, 4], mindspore.int64)
6176
- >>> mask = mindspore.tensor([1, 0, 1, 0], mindspore.bool_)
6265
+ >>> mask = mindspore.tensor([1, 0, 1, 0], mindspore.bool)
6177
6266
  >>> output = mindspore.ops.masked_select(x, mask)
6178
6267
  >>> print(output)
6179
6268
  [1 3]
@@ -6550,6 +6639,20 @@ def mish_ext(input):
6550
6639
  return mish_ext_op(input)
6551
6640
 
6552
6641
 
6642
+ def mla(query, q_rope, kv_cache, k_rope, block_tables, attn_mask=None, deq_scale_qk=None, deq_scale_pv=None, q_seq_lens=None, context_lens=None, head_num=32, scale_value=0.0, kv_head_num=1, mask_mode='MASK_NONE', is_ring=0):
6643
+ r"""
6644
+
6645
+ """
6646
+ return mla_op(query, q_rope, kv_cache, k_rope, block_tables, attn_mask, deq_scale_qk, deq_scale_pv, q_seq_lens, context_lens, head_num, scale_value, kv_head_num, mask_mode, is_ring)
6647
+
6648
+
6649
+ def mla_preprocess(input1, gamma1, beta1, quant_scale1, quant_offset1, wdqkv, bias1, gamma2, beta2, quant_scale2, quant_offset2, gamma3, sin1, cos1, sin2, cos2, key_cache, slot_mapping, wuq, bias2, slot_wuk, de_scale1, de_scale2, ctkv_scale, qnope_scale, krope_cache, param_cache_mode=0):
6650
+ r"""
6651
+
6652
+ """
6653
+ return mla_preprocess_op(input1, gamma1, beta1, quant_scale1, quant_offset1, wdqkv, bias1, gamma2, beta2, quant_scale2, quant_offset2, gamma3, sin1, cos1, sin2, cos2, key_cache, slot_mapping, wuq, bias2, slot_wuk, de_scale1, de_scale2, ctkv_scale, qnope_scale, krope_cache, param_cache_mode)
6654
+
6655
+
6553
6656
  def mm_ext(input, mat2):
6554
6657
  r"""
6555
6658
  Returns the matrix product of two arrays.
@@ -6978,7 +7081,7 @@ def mul(input, other):
6978
7081
  - When the two inputs have different shapes,
6979
7082
  they must be able to broadcast to a common shape.
6980
7083
  - The two inputs can not be bool type at the same time,
6981
- [True, Tensor(True, bool\_), Tensor(np.array([True]), bool\_)] are all considered bool type.
7084
+ [True, Tensor(True), Tensor(np.array([True]))] are all considered bool type.
6982
7085
  - Support implicit type conversion and type promotion.
6983
7086
 
6984
7087
  Args:
@@ -7223,11 +7326,18 @@ def nextafter(input, other):
7223
7326
 
7224
7327
  Examples:
7225
7328
  >>> import mindspore
7226
- >>> input = mindspore.tensor([0.0], mindspore.float32)
7227
- >>> other = mindspore.tensor([0.1], mindspore.float32)
7329
+ >>> import numpy as np
7330
+ >>> eps = np.finfo(np.float32).eps
7331
+ >>> input = mindspore.tensor([1.0], mindspore.float32)
7332
+ >>> other = mindspore.tensor([2.0], mindspore.float32)
7228
7333
  >>> output = mindspore.ops.nextafter(input, other)
7229
- >>> print(output)
7230
- [1.e-45]
7334
+ >>> print(output == eps + 1)
7335
+ [ True]
7336
+ >>> input = mindspore.tensor([1.0, 2.0], mindspore.float32)
7337
+ >>> other = mindspore.tensor([2.0, 1.0], mindspore.float32)
7338
+ >>> output = mindspore.ops.nextafter(input, other)
7339
+ >>> print(output == mindspore.tensor([eps + 1, 2 - eps], mindspore.float32))
7340
+ [ True True]
7231
7341
  """
7232
7342
  return next_after_op(input, other)
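The float32 neighbor relationships used in the updated nextafter example can be checked directly with NumPy's np.nextafter (shown only to motivate the expected values in the new doctest):

import numpy as np

eps = np.finfo(np.float32).eps
one = np.float32(1.0)
two = np.float32(2.0)
print(np.nextafter(one, two) == one + eps)  # True: the next float32 above 1.0 is 1.0 + eps
print(np.nextafter(two, one) == two - eps)  # True: the next float32 below 2.0 is 2.0 - eps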
7233
7343
 
@@ -7281,9 +7391,6 @@ def outer_ext(input, vec2):
7281
7391
  Return outer product of `input` and `vec2`. If `input` is a vector of size :math:`n`
7282
7392
  and `vec2` is a vector of size :math:`m` , then output must be a matrix of shape :math:`(n, m)` .
7283
7393
 
7284
- .. warning::
7285
- This is an experimental API that is subject to change or deletion.
7286
-
7287
7394
  .. note::
7288
7395
  This function does not broadcast.
7289
7396
 
@@ -7374,8 +7481,10 @@ def prelu(input, weight):
7374
7481
  :align: center
7375
7482
 
7376
7483
  .. note::
7377
- Channel dim is the 2nd dim of input. When input has dims < 2, then there is
7378
- no channel dim and the number of channels = 1.
7484
+ - Channel dim is the 2nd dim of input. When input has dims < 2, then there is
7485
+ no channel dim and the number of channels = 1.
7486
+ - In GE mode, the rank of the input tensor must be greater than 1;
7487
+ otherwise, an error will be triggered.
7379
7488
 
7380
7489
  Args:
7381
7490
  input (Tensor): The input Tensor of the activation function.
@@ -7528,12 +7637,13 @@ def range(start, end, step, maxlen=1000000):
7528
7637
  Returns a tensor with a step length of `step` in the interval [ `start` , `end` ).
7529
7638
 
7530
7639
  .. note::
7531
- The types of all 3 inputs must be all integers or floating-point numbers.
7640
+ - The types of all 3 inputs must be all integers or floating-point numbers.
7641
+ - When the input is a tensor, the tensor must contain only one element, whose dtype is Number.
7532
7642
 
7533
7643
  Args:
7534
- start (number): The start value of the interval.
7535
- end (number): The end value of the interval.
7536
- step (number): The interval between each value.
7644
+ start (Union[Number, Tensor]): The start value of the interval.
7645
+ end (Union[Number, Tensor]): The end value of the interval.
7646
+ step (Union[Number, Tensor]): The interval between each value.
7537
7647
  maxlen (int, optional): Memory that can fit `maxlen` many elements
7538
7648
  will be allocated for the output. Optional, must be positive. Default: 1000000.
7539
7649
  If the output has more than `maxlen` elements, a runtime error will occur.
@@ -8020,6 +8130,78 @@ def rfft(input, n=None, dim=-1, norm=None):
8020
8130
  return rfft_op(input, n, dim, norm)
8021
8131
 
8022
8132
 
8133
+ def ring_attention_update(prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout='SBH'):
8134
+ r"""
8135
+ The RingAttentionUpdate operator updates the output of two FlashAttention operations based on their respective softmax max and softmax sum values.
8136
+
8137
+ - S: Sequence length
8138
+ - B: Batch dimension
8139
+ - H: Hidden layer size, equals to N * D
8140
+ - T: time, equals to B*S
8141
+ - N: Number of attention heads
8142
+ - D: Head dimension
8143
+
8144
+ .. warning::
8145
+ - It is only supported on Atlas A2 Training Series Products.
8146
+ - This is an experimental API that is subject to change or deletion.
8147
+ - When `layout` is ``"TND"``, the last dimension of `prev_attn_out` must be a multiple of 64.
8148
+ - When `layout` is ``"TND"``, `actual_seq_qlen` is mandatory.
8149
+ - When `layout` is ``"TND"``, N * D must satisfy the constraint:
8150
+ :math:`(\text{AlignUp}(N*D, 64)*(DataSize*6+8))+(\text{AlignUp}(N*8, 64)*56) <= 192*1024`.
8151
+ :math:`DataSize` is 4 bytes when `prev_attn_out` dtype is float32, 2 bytes when dtype is float16 / bfloat16.
8152
+ - When `layout` is ``"TND"``, if `actual_seq_qlen` is not a non-decreasing sequence from 0 to T, the result is undefined.
8153
+
8154
+ Args:
8155
+ prev_attn_out (Tensor): Output of the first FlashAttention operation. The dtype is float16, float32, bfloat16.
8156
+ The shape is :math:`(S, B, H)` or :math:`(T, N, D)`.
8157
+ prev_softmax_max (Tensor): The max values from the first FlashAttention softmax computation. The dtype float32.
8158
+ The shape is :math:`(B, N, S, 8)` or :math:`(T, N, 8)`. The last dimension contains 8 identical values, which must be positive.
8159
+ prev_softmax_sum (Tensor): The sum values from the first FlashAttention softmax computation.
8160
+ It has the same shape and dtype as `prev_softmax_max`.
8161
+ cur_attn_out (Tensor): Output of the second FlashAttention operation. It has the same shape and dtype as `prev_attn_out`.
8162
+ cur_softmax_max (Tensor): The max values from the second FlashAttention softmax computation. It has the same shape and dtype as `prev_softmax_max`.
8163
+ cur_softmax_sum (Tensor): The sum values from the second FlashAttention softmax computation. It has the same shape and dtype as `prev_softmax_max`.
8164
+ actual_seq_qlen (Tensor, optional): Cumulative sequence length, starting from 0. Required if `layout` is ``"TND"``. Does not take effect if `layout` is ``"SBH"``.
8165
+ The tensor must be 1D and contain non-decreasing integer values starting from 0 to T. Default: ``None``.
8166
+ layout (str, optional): Indicates the input layout, currently support ``"TND"`` and ``"SBH"``. Default: ``"SBH"``.
8167
+
8168
+ Returns:
8169
+ tuple (Tensor), tuple of 3 tensors.
8170
+
8171
+ - **attn_out** (Tensor) - The updated attention out, with the same shape and dtype as `prev_attn_out`.
8172
+ - **softmax_max** (Tensor) - The updated softmax max values, with the same shape and dtype as `prev_softmax_max`.
8173
+ - **softmax_sum** (Tensor) - The updated softmax sum values, with the same shape and dtype as `prev_softmax_max`.
8174
+
8175
+ Raises:
8176
+ RuntimeError: If `layout` is ``"TND"``, and `prev_attn_out`'s last dimension is not aligned to 64.
8177
+ RuntimeError: If `layout` is ``"TND"``, and `actual_seq_qlen` is not provided.
8178
+ RuntimeError: If `layout` is ``"TND"``, and `actual_seq_qlen` is not a non-decreasing sequence from 0 to T.
8179
+ RuntimeError: If `layout` is ``"TND"``, and `prev_attn_out` exceeds the size constraints.
8180
+
8181
+ Supported Platforms:
8182
+ ``Ascend``
8183
+
8184
+ Examples:
8185
+ >>> import numpy as np
8186
+ >>> import mindspore
8187
+ >>> from mindspore import Tensor, ops
8188
+ >>> np.random.seed(123)
8189
+ >>> S, B, H, N = 4, 6, 16, 8
8190
+ >>> prev_attn_out = np.random.uniform(-1.0, 1.0, size=(S, B, H)).astype(np.float32)
8191
+ >>> prev_softmax_max = np.random.uniform(-1.0, 1.0, size=(B, N, S, 8)).astype(np.float32)
8192
+ >>> prev_softmax_sum = np.random.uniform(-1.0, 1.0, size=(B, N, S, 8)).astype(np.float32)
8193
+ >>> cur_attn_out = np.random.uniform(-1.0, 1.0, size=(S, B, H)).astype(np.float32)
8194
+ >>> cur_softmax_max = np.random.uniform(-1.0, 1.0, size=(B, N, S, 8)).astype(np.float32)
8195
+ >>> cur_softmax_sum = np.random.uniform(-1.0, 1.0, size=(B, N, S, 8)).astype(np.float32)
8196
+ >>> inputs_np = [prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum]
8197
+ >>> inputs_ms = [Tensor(item) for item in inputs_np]
8198
+ >>> out = ops.ring_attention_update(*inputs_ms)
8199
+ >>> print(out[0].shape)
8200
+ (4, 6, 16)
8201
+ """
8202
+ return ring_attention_update_op(prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen, layout)
8203
+
8204
+
8023
8205
  def rms_norm(x, gamma, epsilon=1e-6):
8024
8206
  r"""
8025
8207
  The RmsNorm(Root Mean Square Layer Normalization) operator is a normalization operation. Compared to
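For background on the ring_attention_update hunk above: operators of this kind conventionally merge two partial FlashAttention results with a log-sum-exp rescaling. The docstring does not spell out the formula, so the sketch below shows only the standard merge for a single (attn_out, softmax_max, softmax_sum) triple, with the layout and per-head shape handling omitted; the function name and the exact semantics are assumptions, not taken from the source:

import numpy as np

def merge_partial_attention(o1, m1, s1, o2, m2, s2):
    # Standard log-sum-exp merge of two partial softmax-attention results:
    # rescale each partial output by its share of the combined softmax sum.
    m = np.maximum(m1, m2)
    a1 = s1 * np.exp(m1 - m)
    a2 = s2 * np.exp(m2 - m)
    s = a1 + a2
    o = (a1 * o1 + a2 * o2) / s
    return o, m, s

o1, m1, s1 = np.array([0.3]), np.array([1.0]), np.array([2.0])
o2, m2, s2 = np.array([0.7]), np.array([0.5]), np.array([1.5])
print(merge_partial_attention(o1, m1, s1, o2, m2, s2))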
@@ -8215,7 +8397,7 @@ def scalar_cast(input_x, input_y):
8215
8397
 
8216
8398
  Args:
8217
8399
  input_x (scalar): The input scalar. Only constant value is allowed.
8218
- input_y (mindspore.dtype): The type to be cast. Only constant value is allowed. And the value should only be mindspore.int64, mindspore.float64, or mindspore.bool_.
8400
+ input_y (mindspore.dtype): The type to be cast. Only constant value is allowed. And the value should only be mindspore.int64, mindspore.float64, or mindspore.bool.
8219
8401
 
8220
8402
  Returns:
8221
8403
  Scalar. The type is the same as the python type corresponding to `input_y`.
@@ -8725,6 +8907,58 @@ def sin(input):
8725
8907
  return sin_op(input)
8726
8908
 
8727
8909
 
8910
+ def smooth_l1_loss(prediction, target, beta=1.0, reduction='none'):
8911
+ r"""
8912
+ Calculate the smooth L1 loss, and the L1 loss function has robustness.
8913
+
8914
+ Refer to :func:`mindspore.ops.smooth_l1_loss` for more details.
8915
+
8916
+ .. warning::
8917
+ This API has poor performance on CPU, and it is recommended to run it on Ascend or GPU.
8918
+
8919
+ Args:
8920
+ beta (number, optional): A parameter used to control the point where the function will change between
8921
+ L1 loss and L2 loss. Default: ``1.0`` .
8922
+
8923
+ - Ascend: The value should be equal to or greater than zero.
8924
+ - CPU/GPU: The value should be greater than zero.
8925
+ reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
8926
+ ``'sum'`` . Default: ``'none'`` .
8927
+
8928
+ - ``'none'``: no reduction will be applied.
8929
+ - ``'mean'``: compute and return the mean of elements in the output.
8930
+ - ``'sum'``: the output elements will be summed.
8931
+
8932
+ Inputs:
8933
+ - **logits** (Tensor) - Input Tensor of any dimension. Supported dtypes:
8934
+
8935
+ - Ascend: float16, float32, bfloat16.
8936
+ - CPU/GPU: float16, float32, float64.
8937
+ - **labels** (Tensor) - Ground truth data.
8938
+
8939
+ - CPU/Ascend: has the same shape as `logits`; `logits` and `labels` comply with the implicit type conversion rules to make the data types consistent.
8940
+ - GPU: has the same shape and dtype as the `logits`.
8941
+
8942
+ Outputs:
8943
+ Tensor, if `reduction` is ``'none'``, then output is a tensor with the same shape as `logits`. Otherwise the shape of output tensor is :math:`()`.
8944
+
8945
+ Supported Platforms:
8946
+ ``Ascend`` ``GPU`` ``CPU``
8947
+
8948
+ Examples:
8949
+ >>> import mindspore
8950
+ >>> import numpy as np
8951
+ >>> from mindspore import Tensor, ops
8952
+ >>> loss = ops.SmoothL1Loss()
8953
+ >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
8954
+ >>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
8955
+ >>> output = loss(logits, labels)
8956
+ >>> print(output)
8957
+ [0. 0. 0.5]
8958
+ """
8959
+ return smooth_l1_loss_impl(prediction, target, beta, reduction)
8960
+
8961
+
8728
8962
  def softplus_ext(input, beta=1, threshold=20):
8729
8963
  r"""
8730
8964
  Applies softplus function to `input` element-wise.
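The smooth_l1_loss example above (beta=1.0, reduction='none') can be reproduced with the textbook smooth-L1 definition, 0.5*d^2/beta for |d| < beta and |d| - 0.5*beta otherwise. The NumPy sketch below only illustrates the expected values under that standard form (the docstring defers the exact definition to mindspore.ops.smooth_l1_loss):

import numpy as np

logits = np.array([1.0, 2.0, 3.0], dtype=np.float32)
labels = np.array([1.0, 2.0, 2.0], dtype=np.float32)
beta = 1.0

d = np.abs(logits - labels)
loss = np.where(d < beta, 0.5 * d * d / beta, d - 0.5 * beta)
print(loss)  # prints [0.  0.  0.5], matching the docstring example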
@@ -9029,14 +9263,13 @@ def stack_ext(tensors, dim=0):
9029
9263
  :math:`(x_1, x_2, ..., x_{dim}, N, x_{dim+1}, ..., x_R)`.
9030
9264
 
9031
9265
  Args:
9032
- tensors (Union[tuple, list]): A Tuple or list of Tensor objects with the same shape and type.
9266
+ tensors (Union[tuple, list]): A Tuple or list of Tensor objects with the same shape.
9033
9267
  dim (int, optional): Dimension to stack. The range is [-(R+1), R+1). Default: ``0`` .
9034
9268
 
9035
9269
  Returns:
9036
- Tensor. A stacked Tensor with the same type as `tensors`.
9270
+ A stacked Tensor.
9037
9271
 
9038
9272
  Raises:
9039
- TypeError: If the data types of elements in `tensors` are not the same.
9040
9273
  ValueError: If `dim` is out of the range [-(R+1), R+1);
9041
9274
  or if the shapes of elements in `tensors` are not the same.
9042
9275
 
@@ -9184,11 +9417,11 @@ def sub_ext(input, other, alpha=1):
9184
9417
  input (Union[Tensor, number.Number, bool]): The first input is a number.Number or
9185
9418
  a bool or a tensor whose data type is
9186
9419
  `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
9187
- `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
9420
+ `bool <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
9188
9421
  other (Union[Tensor, number.Number, bool]): The second input, is a number.Number or
9189
9422
  a bool or a tensor whose data type is
9190
9423
  `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
9191
- `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
9424
+ `bool <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
9192
9425
  alpha (number.Number): A scaling factor applied to `other`, default 1.
9193
9426
 
9194
9427
  Returns:
@@ -9233,7 +9466,7 @@ def sub(input, other):
9233
9466
  Note:
9234
9467
  - When the two inputs have different shapes, they must be able to broadcast to a common shape.
9235
9468
  - The two inputs can not be bool type at the same time,
9236
- [True, Tensor(True, bool\_), Tensor(np.array([True]), bool\_)] are all considered bool type.
9469
+ [True, Tensor(True), Tensor(np.array([True]))] are all considered bool type.
9237
9470
  - Support implicit type conversion and type promotion.
9238
9471
 
9239
9472
  Args:
@@ -9672,9 +9905,6 @@ def transpose_ext_view(input, dim0, dim1):
9672
9905
  r"""
9673
9906
  Interchange two axes of a tensor.
9674
9907
 
9675
- .. warning::
9676
- This is an experimental API that is subject to change or deletion.
9677
-
9678
9908
  Args:
9679
9909
  input(Tensor): Input tensor.
9680
9910
  dim0 (int): First axis.
@@ -9702,17 +9932,17 @@ def transpose_ext_view(input, dim0, dim1):
9702
9932
  return transpose_ext_view_op(input, dim0, dim1)
9703
9933
 
9704
9934
 
9705
- def transpose(input, input_perm):
9935
+ def transpose(input, dims):
9706
9936
  r"""
9707
9937
  Transpose dimensions of the input tensor according to input permutation.
9708
9938
 
9709
9939
  Note:
9710
- On GPU and CPU, if the value of `input_perm` is negative, its actual value is `input_perm[i] + rank(input)`.
9711
- Negative value of `input_perm` is not supported on Ascend.
9940
+ On GPU and CPU, if the value of `dims` is negative, its actual value is `dims[i] + rank(input)`.
9941
+ Negative value of `dims` is not supported on Ascend.
9712
9942
 
9713
9943
  Args:
9714
9944
  input (Tensor): The input tensor.
9715
- input_perm (tuple[int]): Specify the new axis ordering.
9945
+ dims (Union[tuple[int], list[int]]): Specify the new axis ordering.
9716
9946
 
9717
9947
  Returns:
9718
9948
  Tensor
@@ -9732,7 +9962,7 @@ def transpose(input, input_perm):
9732
9962
  [ 8. 11.]
9733
9963
  [ 9. 12.]]]
9734
9964
  """
9735
- return transpose_op(input, input_perm)
9965
+ return transpose_op(input, dims)
9736
9966
 
9737
9967
 
9738
9968
  def transpose_view(input, input_perm):
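A minimal illustration of the renamed `dims` argument in the transpose hunk above. Only the parameter name changed, so passing the permutation positionally works with both the old and the new signature, and per the new docstring a list is accepted as well as a tuple:

import mindspore

x = mindspore.tensor([[1., 2., 3.], [4., 5., 6.]], mindspore.float32)
y = mindspore.ops.transpose(x, (1, 0))  # permutation given positionally; a list like [1, 0] also works
print(y.shape)  # (3, 2)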
@@ -9846,9 +10076,6 @@ def triu(input, diagonal=0):
9846
10076
  r"""
9847
10077
  Zero the input tensor below the diagonal specified.
9848
10078
 
9849
- .. warning::
9850
- This is an experimental API that is subject to change or deletion.
9851
-
9852
10079
  Args:
9853
10080
  input (Tensor): The input tensor.
9854
10081
  diagonal (int, optional): The diagonal specified of 2-D tensor. Default ``0`` represents the main diagonal.
@@ -10283,7 +10510,7 @@ def grouped_matmul_v2(x, weight, bias=None, scale=None, offset=None, antiquant_s
10283
10510
  return grouped_matmul_v2_op(x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type)
10284
10511
 
10285
10512
 
10286
- def grouped_matmul_v4(x, weight, bias=None, scale=None, offset=None, antiquant_scale=None, antiquant_offset=None, pre_token_scale=None, group_list=None, activation_input=None, activation_quant_scale=None, activation_quant_offset=None, split_item=0, group_type=-1, group_list_type=0, act_type=0):
10513
+ def grouped_matmul_v4(x, weight, bias=None, scale=None, offset=None, antiquant_scale=None, antiquant_offset=None, pre_token_scale=None, group_list=None, activation_input=None, activation_quant_scale=None, activation_quant_offset=None, split_item=0, group_type=-1, group_list_type=0, act_type=0, output_dtype=None):
10287
10514
  r"""
10288
10515
  Group calculation matmul.
10289
10516
 
@@ -10298,8 +10525,10 @@ def grouped_matmul_v4(x, weight, bias=None, scale=None, offset=None, antiquant_s
10298
10525
  y_i = x_i\times (weight_i + antiquant\_offset_i) * antiquant\_scale_i + bias_i
10299
10526
 
10300
10527
  .. note::
10301
- Only when `bias` , `scale` , `offset` , `antiquant_scale` and `antiquant_offset` are all None, `group_type` is 0,
10302
- and `split_item` is 3, the reverse derivative is supported.
10528
+ - Only when `bias` , `scale` , `offset` , `antiquant_scale` and `antiquant_offset` are all None, `group_type` is 0,
10529
+ and `split_item` is 3, the reverse derivative is supported.
10530
+ - When `x` type is int8 and `weight` type is int4, the `scale` should be of the uint64 data type,
10531
+ but its memory needs to be arranged in float32 format.
10303
10532
 
10304
10533
  ** Per-Token-Quant **
10305
10534
 
@@ -10339,6 +10568,8 @@ def grouped_matmul_v4(x, weight, bias=None, scale=None, offset=None, antiquant_s
10339
10568
  as the cumsum of grouping size in each group, and 1 represents the positions as the grouping size in
10340
10569
  each group. Default: ``0``.
10341
10570
  act_type (int): Activation function type. Currently not supported. Default: ``0``.
10571
+ output_dtype (mindspore.dtype): Specifies the output data type, currently taking effect only when input x is int8 and weight is int4.
10572
+ If None is passed in, bfloat16 will be used by default. Default: ``None``.
10342
10573
 
10343
10574
 
10344
10575
  Parameter limitations 1
@@ -10429,7 +10660,7 @@ def grouped_matmul_v4(x, weight, bias=None, scale=None, offset=None, antiquant_s
10429
10660
  [108 112]
10430
10661
  [108 112]]
10431
10662
  """
10432
- return grouped_matmul_v4_op(x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, pre_token_scale, group_list, activation_input, activation_quant_scale, activation_quant_offset, split_item, group_type, group_list_type, act_type)
10663
+ return grouped_matmul_v4_op(x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, pre_token_scale, group_list, activation_input, activation_quant_scale, activation_quant_offset, split_item, group_type, group_list_type, act_type, output_dtype)
10433
10664
 
10434
10665
 
10435
10666
  def kv_cache_scatter_update(var, indices, updates, axis, reduce='none'):