mindspore-2.4.1-cp39-cp39-win_amd64.whl → mindspore-2.5.0-cp39-cp39-win_amd64.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (395)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +8 -3
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +0 -5
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/compile_config.py +64 -0
  11. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
  13. mindspore/_extends/parse/parser.py +23 -5
  14. mindspore/_extends/parse/standard_method.py +123 -27
  15. mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
  16. mindspore/amp.py +7 -1
  17. mindspore/atlprov.dll +0 -0
  18. mindspore/avcodec-59.dll +0 -0
  19. mindspore/avdevice-59.dll +0 -0
  20. mindspore/avfilter-8.dll +0 -0
  21. mindspore/avformat-59.dll +0 -0
  22. mindspore/avutil-57.dll +0 -0
  23. mindspore/boost/boost_cell_wrapper.py +136 -41
  24. mindspore/c1.dll +0 -0
  25. mindspore/c1xx.dll +0 -0
  26. mindspore/c2.dll +0 -0
  27. mindspore/common/__init__.py +3 -1
  28. mindspore/common/_register_for_tensor.py +0 -1
  29. mindspore/common/_stub_tensor.py +25 -4
  30. mindspore/common/_tensor_cpp_method.py +17 -0
  31. mindspore/common/_tensor_docs.py +6132 -0
  32. mindspore/common/api.py +99 -25
  33. mindspore/common/dtype.py +34 -34
  34. mindspore/common/dump.py +2 -1
  35. mindspore/common/file_system.py +8 -1
  36. mindspore/common/generator.py +2 -0
  37. mindspore/common/hook_handle.py +3 -1
  38. mindspore/common/initializer.py +3 -4
  39. mindspore/common/lazy_inline.py +8 -2
  40. mindspore/common/mindir_util.py +10 -2
  41. mindspore/common/parameter.py +30 -27
  42. mindspore/common/tensor.py +713 -1337
  43. mindspore/communication/__init__.py +1 -1
  44. mindspore/communication/_comm_helper.py +10 -0
  45. mindspore/communication/comm_func.py +215 -173
  46. mindspore/communication/management.py +23 -20
  47. mindspore/context.py +292 -193
  48. mindspore/dataset/__init__.py +23 -19
  49. mindspore/dataset/callback/ds_callback.py +2 -1
  50. mindspore/dataset/core/config.py +84 -3
  51. mindspore/dataset/engine/cache_admin.py +3 -3
  52. mindspore/dataset/engine/cache_client.py +5 -4
  53. mindspore/dataset/engine/datasets.py +192 -149
  54. mindspore/dataset/engine/datasets_audio.py +14 -0
  55. mindspore/dataset/engine/datasets_standard_format.py +28 -11
  56. mindspore/dataset/engine/datasets_text.py +38 -1
  57. mindspore/dataset/engine/datasets_user_defined.py +125 -65
  58. mindspore/dataset/engine/datasets_vision.py +81 -8
  59. mindspore/dataset/engine/iterators.py +281 -63
  60. mindspore/dataset/engine/obs/util.py +8 -0
  61. mindspore/dataset/engine/queue.py +40 -0
  62. mindspore/dataset/engine/samplers.py +26 -2
  63. mindspore/dataset/engine/serializer_deserializer.py +1 -1
  64. mindspore/dataset/engine/validators.py +43 -11
  65. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  66. mindspore/dataset/transforms/transforms.py +29 -12
  67. mindspore/dataset/vision/validators.py +1 -2
  68. mindspore/device_context/__init__.py +21 -0
  69. mindspore/device_context/ascend/__init__.py +25 -0
  70. mindspore/device_context/ascend/device.py +72 -0
  71. mindspore/device_context/ascend/op_debug.py +94 -0
  72. mindspore/device_context/ascend/op_precision.py +193 -0
  73. mindspore/device_context/ascend/op_tuning.py +127 -0
  74. mindspore/device_context/cpu/__init__.py +25 -0
  75. mindspore/device_context/cpu/device.py +62 -0
  76. mindspore/device_context/cpu/op_tuning.py +43 -0
  77. mindspore/device_context/gpu/__init__.py +21 -0
  78. mindspore/device_context/gpu/device.py +70 -0
  79. mindspore/device_context/gpu/op_precision.py +67 -0
  80. mindspore/device_context/gpu/op_tuning.py +175 -0
  81. mindspore/device_manager.py +134 -0
  82. mindspore/dnnl.dll +0 -0
  83. mindspore/dpcmi.dll +0 -0
  84. mindspore/experimental/llm_boost/__init__.py +3 -2
  85. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  86. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  87. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  88. mindspore/experimental/llm_boost/atb/boost_base.py +239 -64
  89. mindspore/experimental/llm_boost/atb/llama_boost.py +52 -30
  90. mindspore/experimental/llm_boost/atb/qwen_boost.py +47 -24
  91. mindspore/experimental/llm_boost/register.py +1 -0
  92. mindspore/experimental/optim/adadelta.py +26 -22
  93. mindspore/experimental/optim/adam.py +3 -0
  94. mindspore/experimental/optim/lr_scheduler.py +33 -24
  95. mindspore/experimental/optim/radam.py +33 -30
  96. mindspore/hal/device.py +28 -0
  97. mindspore/hal/event.py +17 -0
  98. mindspore/hal/memory.py +94 -3
  99. mindspore/hal/stream.py +91 -6
  100. mindspore/include/api/context.h +1 -2
  101. mindspore/include/dataset/constants.h +2 -2
  102. mindspore/jpeg62.dll +0 -0
  103. mindspore/log.py +12 -0
  104. mindspore/mindrecord/__init__.py +1 -1
  105. mindspore/mindrecord/config.py +17 -316
  106. mindspore/mindrecord/filereader.py +1 -9
  107. mindspore/mindrecord/filewriter.py +5 -15
  108. mindspore/mindrecord/mindpage.py +1 -9
  109. mindspore/mindspore_backend.dll +0 -0
  110. mindspore/mindspore_common.dll +0 -0
  111. mindspore/mindspore_core.dll +0 -0
  112. mindspore/mindspore_glog.dll +0 -0
  113. mindspore/mindspore_ops.dll +0 -0
  114. mindspore/mint/__init__.py +824 -218
  115. mindspore/mint/distributed/__init__.py +66 -4
  116. mindspore/mint/distributed/distributed.py +2594 -44
  117. mindspore/mint/linalg/__init__.py +6 -0
  118. mindspore/mint/nn/__init__.py +473 -14
  119. mindspore/mint/nn/functional.py +486 -11
  120. mindspore/mint/nn/layer/__init__.py +17 -4
  121. mindspore/mint/nn/layer/_functions.py +330 -0
  122. mindspore/mint/nn/layer/activation.py +169 -1
  123. mindspore/mint/nn/layer/basic.py +123 -0
  124. mindspore/mint/nn/layer/conv.py +727 -0
  125. mindspore/mint/nn/layer/normalization.py +215 -19
  126. mindspore/mint/nn/layer/padding.py +797 -0
  127. mindspore/mint/nn/layer/pooling.py +170 -0
  128. mindspore/mint/optim/__init__.py +2 -1
  129. mindspore/mint/optim/adam.py +223 -0
  130. mindspore/mint/optim/adamw.py +26 -19
  131. mindspore/mint/special/__init__.py +2 -1
  132. mindspore/msobj140.dll +0 -0
  133. mindspore/mspdb140.dll +0 -0
  134. mindspore/mspdbcore.dll +0 -0
  135. mindspore/mspdbst.dll +0 -0
  136. mindspore/mspft140.dll +0 -0
  137. mindspore/msvcdis140.dll +0 -0
  138. mindspore/msvcp140_1.dll +0 -0
  139. mindspore/msvcp140_2.dll +0 -0
  140. mindspore/msvcp140_atomic_wait.dll +0 -0
  141. mindspore/msvcp140_codecvt_ids.dll +0 -0
  142. mindspore/multiprocessing/__init__.py +5 -0
  143. mindspore/nn/__init__.py +2 -0
  144. mindspore/nn/cell.py +142 -21
  145. mindspore/nn/dynamic_lr.py +2 -1
  146. mindspore/nn/layer/activation.py +6 -6
  147. mindspore/nn/layer/basic.py +35 -25
  148. mindspore/nn/layer/channel_shuffle.py +3 -3
  149. mindspore/nn/layer/conv.py +3 -0
  150. mindspore/nn/layer/embedding.py +3 -3
  151. mindspore/nn/layer/normalization.py +8 -7
  152. mindspore/nn/layer/padding.py +4 -3
  153. mindspore/nn/layer/pooling.py +55 -23
  154. mindspore/nn/layer/rnn_cells.py +1 -1
  155. mindspore/nn/layer/rnns.py +2 -1
  156. mindspore/nn/layer/timedistributed.py +5 -5
  157. mindspore/nn/layer/transformer.py +48 -26
  158. mindspore/nn/learning_rate_schedule.py +5 -3
  159. mindspore/nn/loss/loss.py +31 -36
  160. mindspore/nn/optim/ada_grad.py +1 -0
  161. mindspore/nn/optim/adadelta.py +2 -2
  162. mindspore/nn/optim/adam.py +1 -1
  163. mindspore/nn/optim/lars.py +1 -4
  164. mindspore/nn/optim/optimizer.py +1 -1
  165. mindspore/nn/optim/rprop.py +2 -2
  166. mindspore/nn/optim/thor.py +2 -1
  167. mindspore/nn/utils/__init__.py +22 -0
  168. mindspore/nn/utils/init.py +73 -0
  169. mindspore/nn/wrap/cell_wrapper.py +4 -6
  170. mindspore/nn/wrap/loss_scale.py +3 -4
  171. mindspore/numpy/array_creations.py +60 -62
  172. mindspore/numpy/array_ops.py +148 -143
  173. mindspore/numpy/logic_ops.py +41 -42
  174. mindspore/numpy/math_ops.py +361 -359
  175. mindspore/numpy/utils.py +16 -16
  176. mindspore/numpy/utils_const.py +4 -4
  177. mindspore/opencv_core452.dll +0 -0
  178. mindspore/opencv_imgcodecs452.dll +0 -0
  179. mindspore/opencv_imgproc452.dll +0 -0
  180. mindspore/ops/__init__.py +2 -1
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +107 -8
  182. mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
  183. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  184. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  185. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  186. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  187. mindspore/ops/_vmap/vmap_array_ops.py +20 -19
  188. mindspore/ops/_vmap/vmap_base.py +0 -2
  189. mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
  190. mindspore/ops/_vmap/vmap_math_ops.py +11 -9
  191. mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
  192. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
  193. mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
  194. mindspore/ops/auto_generate/gen_extend_func.py +554 -60
  195. mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
  196. mindspore/ops/auto_generate/gen_ops_prim.py +8027 -3411
  197. mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
  198. mindspore/ops/composite/base.py +1 -1
  199. mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
  200. mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
  201. mindspore/ops/function/__init__.py +12 -0
  202. mindspore/ops/function/array_func.py +561 -159
  203. mindspore/ops/function/clip_func.py +64 -0
  204. mindspore/ops/function/debug_func.py +28 -20
  205. mindspore/ops/function/image_func.py +1 -1
  206. mindspore/ops/function/linalg_func.py +5 -4
  207. mindspore/ops/function/math_func.py +1664 -294
  208. mindspore/ops/function/nn_func.py +988 -317
  209. mindspore/ops/function/parameter_func.py +3 -56
  210. mindspore/ops/function/random_func.py +243 -33
  211. mindspore/ops/function/sparse_unary_func.py +1 -1
  212. mindspore/ops/functional.py +18 -5
  213. mindspore/ops/functional_overload.py +897 -0
  214. mindspore/ops/operations/__init__.py +3 -2
  215. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  216. mindspore/ops/operations/_grad_ops.py +2 -34
  217. mindspore/ops/operations/_infer_ops.py +2 -1
  218. mindspore/ops/operations/_inner_ops.py +38 -8
  219. mindspore/ops/operations/array_ops.py +45 -303
  220. mindspore/ops/operations/comm_ops.py +23 -17
  221. mindspore/ops/operations/custom_ops.py +7 -49
  222. mindspore/ops/operations/debug_ops.py +42 -47
  223. mindspore/ops/operations/inner_ops.py +6 -4
  224. mindspore/ops/operations/linalg_ops.py +3 -2
  225. mindspore/ops/operations/manually_defined/ops_def.py +185 -104
  226. mindspore/ops/operations/math_ops.py +11 -216
  227. mindspore/ops/operations/nn_ops.py +153 -310
  228. mindspore/ops/primitive.py +23 -21
  229. mindspore/ops/tensor_method.py +1669 -0
  230. mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
  231. mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
  232. mindspore/ops_generate/arg_handler.py +0 -61
  233. mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
  234. mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
  235. mindspore/ops_generate/base_generator.py +11 -0
  236. mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
  237. mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
  238. mindspore/ops_generate/functional_overload_py_generator.py +110 -0
  239. mindspore/ops_generate/functions_cc_generator.py +233 -0
  240. mindspore/ops_generate/gen_aclnn_implement.py +110 -114
  241. mindspore/ops_generate/gen_constants.py +157 -3
  242. mindspore/ops_generate/gen_ops.py +245 -990
  243. mindspore/ops_generate/gen_pyboost_func.py +97 -998
  244. mindspore/ops_generate/gen_utils.py +119 -33
  245. mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
  246. mindspore/ops_generate/op_api_proto.py +206 -0
  247. mindspore/ops_generate/op_def_py_generator.py +131 -0
  248. mindspore/ops_generate/op_prim_py_generator.py +480 -0
  249. mindspore/ops_generate/op_proto.py +373 -108
  250. mindspore/ops_generate/op_template_parser.py +436 -0
  251. mindspore/ops_generate/ops_def_cc_generator.py +288 -0
  252. mindspore/ops_generate/ops_def_h_generator.py +74 -0
  253. mindspore/ops_generate/ops_name_h_generator.py +68 -0
  254. mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
  255. mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
  256. mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
  257. mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
  258. mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
  259. mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
  260. mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
  261. mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
  262. mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
  263. mindspore/ops_generate/pyboost_utils.py +92 -33
  264. mindspore/ops_generate/template.py +294 -44
  265. mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
  266. mindspore/parallel/__init__.py +3 -3
  267. mindspore/parallel/_auto_parallel_context.py +44 -34
  268. mindspore/parallel/_cell_wrapper.py +22 -3
  269. mindspore/parallel/_parallel_serialization.py +13 -2
  270. mindspore/parallel/_utils.py +4 -2
  271. mindspore/parallel/algo_parameter_config.py +1 -1
  272. mindspore/parallel/checkpoint_transform.py +44 -0
  273. mindspore/parallel/cluster/process_entity/_api.py +131 -37
  274. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  275. mindspore/parallel/cluster/run.py +20 -3
  276. mindspore/parallel/parameter_broadcast.py +1 -1
  277. mindspore/parallel/shard.py +3 -0
  278. mindspore/parallel/transform_safetensors.py +119 -253
  279. mindspore/pgodb140.dll +0 -0
  280. mindspore/pgort140.dll +0 -0
  281. mindspore/profiler/__init__.py +17 -4
  282. mindspore/profiler/analysis/__init__.py +0 -0
  283. mindspore/profiler/analysis/parser/__init__.py +0 -0
  284. mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
  285. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  286. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  287. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  288. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  289. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  290. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
  291. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  292. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
  293. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  294. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  295. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  296. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  297. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  298. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  299. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  300. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  301. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  302. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  303. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
  304. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  305. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  306. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  307. mindspore/profiler/analysis/task_manager.py +131 -0
  308. mindspore/profiler/analysis/time_converter.py +84 -0
  309. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  310. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
  311. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  312. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
  313. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
  314. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
  315. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
  316. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  317. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  318. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
  319. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  320. mindspore/profiler/analysis/work_flow.py +73 -0
  321. mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
  322. mindspore/profiler/common/command_executor.py +90 -0
  323. mindspore/profiler/common/constant.py +174 -3
  324. mindspore/profiler/common/file_manager.py +208 -0
  325. mindspore/profiler/common/log.py +130 -0
  326. mindspore/profiler/common/msprof_cmd_tool.py +202 -0
  327. mindspore/profiler/common/path_manager.py +371 -0
  328. mindspore/profiler/common/process_bar.py +168 -0
  329. mindspore/profiler/common/process_pool.py +9 -3
  330. mindspore/profiler/common/profiler_context.py +476 -0
  331. mindspore/profiler/common/profiler_info.py +304 -0
  332. mindspore/profiler/common/profiler_output_path.py +284 -0
  333. mindspore/profiler/common/profiler_parameters.py +210 -0
  334. mindspore/profiler/common/profiler_path_manager.py +120 -0
  335. mindspore/profiler/common/record_function.py +76 -0
  336. mindspore/profiler/common/tlv_decoder.py +76 -0
  337. mindspore/profiler/common/util.py +75 -2
  338. mindspore/profiler/dynamic_profiler.py +270 -37
  339. mindspore/profiler/envprofiler.py +138 -0
  340. mindspore/profiler/mstx.py +199 -0
  341. mindspore/profiler/platform/__init__.py +21 -0
  342. mindspore/profiler/platform/base_profiler.py +40 -0
  343. mindspore/profiler/platform/cpu_profiler.py +124 -0
  344. mindspore/profiler/platform/gpu_profiler.py +74 -0
  345. mindspore/profiler/platform/npu_profiler.py +309 -0
  346. mindspore/profiler/profiler.py +580 -93
  347. mindspore/profiler/profiler_action_controller.py +187 -0
  348. mindspore/profiler/profiler_interface.py +114 -0
  349. mindspore/profiler/schedule.py +208 -0
  350. mindspore/rewrite/api/symbol_tree.py +1 -2
  351. mindspore/run_check/_check_version.py +18 -13
  352. mindspore/runtime/__init__.py +37 -0
  353. mindspore/runtime/device.py +27 -0
  354. mindspore/runtime/event.py +209 -0
  355. mindspore/runtime/executor.py +148 -0
  356. mindspore/runtime/memory.py +392 -0
  357. mindspore/runtime/stream.py +460 -0
  358. mindspore/runtime/thread_bind_core.py +401 -0
  359. mindspore/swresample-4.dll +0 -0
  360. mindspore/swscale-6.dll +0 -0
  361. mindspore/tbbmalloc.dll +0 -0
  362. mindspore/tinyxml2.dll +0 -0
  363. mindspore/train/__init__.py +2 -2
  364. mindspore/train/_utils.py +53 -18
  365. mindspore/train/amp.py +8 -4
  366. mindspore/train/callback/_checkpoint.py +32 -18
  367. mindspore/train/callback/_early_stop.py +1 -1
  368. mindspore/train/callback/_flops_collector.py +105 -69
  369. mindspore/train/callback/_history.py +1 -1
  370. mindspore/train/callback/_summary_collector.py +44 -6
  371. mindspore/train/callback/_tft_register.py +37 -15
  372. mindspore/train/dataset_helper.py +11 -11
  373. mindspore/train/metrics/precision.py +4 -5
  374. mindspore/train/mind_ir_pb2.py +167 -46
  375. mindspore/train/model.py +13 -14
  376. mindspore/train/serialization.py +461 -72
  377. mindspore/train/summary/summary_record.py +1 -2
  378. mindspore/train/train_thor/model_thor.py +1 -1
  379. mindspore/turbojpeg.dll +0 -0
  380. mindspore/utils/__init__.py +4 -2
  381. mindspore/utils/dryrun.py +138 -0
  382. mindspore/utils/runtime_execution_order_check.py +550 -0
  383. mindspore/vcmeta.dll +0 -0
  384. mindspore/vcruntime140.dll +0 -0
  385. mindspore/vcruntime140_1.dll +0 -0
  386. mindspore/version.py +1 -1
  387. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/METADATA +3 -4
  388. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/RECORD +391 -265
  389. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
  390. mindspore/common/_tensor_overload.py +0 -139
  391. mindspore/mindspore_np_dtype.dll +0 -0
  392. mindspore/profiler/envprofiling.py +0 -254
  393. mindspore/profiler/profiling.py +0 -1926
  394. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
  395. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
mindspore/experimental/llm_boost/atb/boost_base.py

@@ -13,17 +13,32 @@
 # limitations under the License.
 # ============================================================================
 """boost base class"""
+from enum import Enum
 import numpy as np
 import mindspore as ms
 from mindspore import ops, Tensor
+from mindspore import log as logger
 from mindspore.ops import operations as P
 import mindspore.common.dtype as mstype
 from mindspore._c_expression import _set_format
-
 from mindspore.common.parameter import Parameter
 from mindspore.experimental.llm_boost.utils import get_real_rank, get_real_group_size
 from mindspore.common.initializer import Zero
 
+FORMAT_NZ = "FRACTAL_NZ"
+BUILDIN_BACKEND_NAME = "ATB"
+
+
+class PositionEmbeddingType(int, Enum):
+    ROPE = 0
+    ALIBI = 1
+    ABSOLUTE = 2
+
+
+class NormType(int, Enum):
+    RMS_NORM = 0
+    LAYER_NORM = 1
+
 
 class AttentionMask:
     """attention mask"""
@@ -31,30 +46,34 @@ class AttentionMask:
     @classmethod
     def static(cls, max_seq_len, dtype=mstype.float16, need_nz=False):
         """cache mask"""
-        bias_cache = Tensor(np.tril(np.ones((max_seq_len, max_seq_len), dtype=np.bool_))).reshape(max_seq_len,
-                                                                                                  max_seq_len)
+        bias_cache = Tensor(
+            np.tril(np.ones((max_seq_len, max_seq_len), dtype=np.bool_))
+        ).reshape(max_seq_len, max_seq_len)
         bias_cache = ~bias_cache
         if dtype == mstype.float16:
             mask_value = Tensor(np.finfo(np.float32).min, mstype.float16)
         else:
             mask_value = Tensor(1)
-        attn_mask = ops.masked_fill(Tensor(np.zeros(
-            (max_seq_len, max_seq_len)), dtype=mstype.float16), bias_cache, mask_value)
+        attn_mask = ops.masked_fill(
+            Tensor(np.zeros((max_seq_len, max_seq_len)), dtype=mstype.float16),
+            bias_cache,
+            mask_value,
+        )
         if need_nz:
             # ND -> NZ
             attn_mask = ops.reshape(attn_mask, (1, max_seq_len, max_seq_len))
-            attn_mask = ops.reshape(
-                attn_mask, (1, max_seq_len, max_seq_len // 16, 16))
+            attn_mask = ops.reshape(attn_mask, (1, max_seq_len, max_seq_len // 16, 16))
             attn_mask = ops.transpose(attn_mask, (0, 2, 1, 3)).contiguous()
-            attn_mask = _set_format(attn_mask, "FRACTAL_NZ")
+            attn_mask = _set_format(attn_mask, FORMAT_NZ)
         return attn_mask
 
 
-class AtbBoostBase():
+class AtbBoostBase:
     """atb boost base class"""
 
     def __init__(self, config):
         super().__init__()
+        self.backend_name = BUILDIN_BACKEND_NAME
         self.is_first_iteration = False
         self.config = config
         self.dtype = config.compute_dtype
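
Aside (not part of the diff): AttentionMask.static builds a standard causal mask — a lower-triangular boolean matrix marks the visible positions, its inverse marks the future positions, and masked_fill writes a large negative value there so softmax drives those scores to zero. A pure-NumPy sketch of the same construction, for illustration only:

import numpy as np

max_seq_len = 4
# Lower triangle = positions a token may attend to; invert to get the
# future positions that must be masked, mirroring `~bias_cache` above.
visible = np.tril(np.ones((max_seq_len, max_seq_len), dtype=np.bool_))
future = ~visible

# masked_fill equivalent: 0 where visible, float32-min where future.
mask_value = np.finfo(np.float32).min
attn_mask = np.where(future, mask_value, 0.0).astype(np.float32)
print(attn_mask[0])  # [ 0.0e+00 -3.4e+38 -3.4e+38 -3.4e+38]
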
@@ -68,27 +87,97 @@ class AtbBoostBase():
         self.need_nz = config.need_nz
         self.placeholder = Tensor(np.zeros(1), dtype=self.dtype)
         self.lm_head_indices_fake = Tensor([0], dtype=mstype.int64)
-        self.position_embedding_type = "ROPE"
+        self.position_embedding_type = PositionEmbeddingType.ROPE
         self.add_norm_enable = True
         self.max_decode_length = self.config.max_decode_length
         self.max_base_len = 128
         self.attn_mask = AttentionMask.static(
-            self.max_base_len, dtype=self.dtype, need_nz=self.need_nz)
+            self.max_base_len, dtype=self.dtype, need_nz=self.need_nz
+        )
 
         self.cast = P.Cast()
         self.reshape = P.Reshape()
         self.kv_quant = None
         self.rank_id = get_real_rank()
         self.device_num = get_real_group_size()
+        self.ascend_weight = []
+        self.k_caches = []
+        self.v_caches = []
 
     def _convert_tensor_format_and_dtype(self, tensor, dtype=mstype.float16):
        tensor = self.cast(tensor, dtype=dtype)
        if self.need_nz:
-            tensor = _set_format(tensor, "FRACTAL_NZ")
+            tensor = _set_format(tensor, FORMAT_NZ)
        return tensor
 
+    def _convert_qkv_concat_weight(self, param_dict):
+        """convert qkv concat weight"""
+        for i in range(self.num_layers):
+            # qkv weight concat
+            wq_weight_name = f"model.layers.{i}.attention.wq.weight"
+            wk_weight_name = f"model.layers.{i}.attention.wk.weight"
+            wv_weight_name = f"model.layers.{i}.attention.wv.weight"
+            qkv_concat_weight_name = f"model.layers.{i}.attention.w_qkv.weight"
+            if wq_weight_name not in param_dict:
+                break
+            wq_weight = param_dict[wq_weight_name].asnumpy()
+            wk_weight = param_dict[wk_weight_name].asnumpy()
+            wv_weight = param_dict[wv_weight_name].asnumpy()
+            qkv_weight = np.concatenate((wq_weight, wk_weight, wv_weight), 0)
+            param_dict[qkv_concat_weight_name] = Parameter(
+                qkv_weight, name=qkv_concat_weight_name
+            )
+
+            # gate hidden weight concat
+            ffn_gate_weight_name = f"model.layers.{i}.feed_forward.w1.weight"
+            ffn_hidden_weight_name = f"model.layers.{i}.feed_forward.w3.weight"
+            gate_hidden_concat_weight_name = (
+                f"model.layers.{i}.feed_forward.w_gate_hidden.weight"
+            )
+
+            ffn_gate_weight = param_dict[ffn_gate_weight_name].asnumpy()
+            ffn_hidden_weight = param_dict[ffn_hidden_weight_name].asnumpy()
+            gate_hidden_weight = np.concatenate((ffn_gate_weight, ffn_hidden_weight), 0)
+            param_dict[gate_hidden_concat_weight_name] = Parameter(
+                gate_hidden_weight, name=gate_hidden_concat_weight_name
+            )
+
+            param_dict.pop(wq_weight_name)
+            param_dict.pop(wk_weight_name)
+            param_dict.pop(wv_weight_name)
+            param_dict.pop(ffn_gate_weight_name)
+            param_dict.pop(ffn_hidden_weight_name)
+            logger.info(f"transform: {qkv_concat_weight_name}")
+            logger.info(f"transform: {gate_hidden_concat_weight_name}")
+
+        for i in range(self.num_layers):
+            # qkv bias concat
+            wq_bias_name = f"model.layers.{i}.attention.wq.bias"
+            wk_bias_name = f"model.layers.{i}.attention.wk.bias"
+            wv_bias_name = f"model.layers.{i}.attention.wv.bias"
+            qkv_concat_bias_name = f"model.layers.{i}.attention.w_qkv.bias"
+            if wq_bias_name not in param_dict:
+                break
+
+            wq_bias_weight = param_dict[wq_bias_name].asnumpy()
+            wk_bias_weight = param_dict[wk_bias_name].asnumpy()
+            wv_bias_weight = param_dict[wv_bias_name].asnumpy()
+            qkv_bias_weight = np.concatenate(
+                (wq_bias_weight, wk_bias_weight, wv_bias_weight), 0
+            )
+            param_dict[qkv_concat_bias_name] = Parameter(
+                qkv_bias_weight, name=qkv_concat_bias_name
+            )
+
+            param_dict.pop(wq_bias_name)
+            param_dict.pop(wk_bias_name)
+            param_dict.pop(wv_bias_name)
+            logger.info(f"transform: {qkv_concat_bias_name}")
+        return param_dict
+
     def set_weights(self, parm_dict, dtype=mstype.float16):
         """set weights for llm boost"""
+        self._convert_qkv_concat_weight(parm_dict)
         embedding_weight_name = "model.tok_embeddings.embedding_weight"
         attention_norm_name = "attention_norm"
         qkv_name = "attention.w_qkv"
@@ -101,45 +190,88 @@ class AtbBoostBase():
         placeholder = Parameter(Tensor(np.zeros(1), dtype=dtype))
 
         ascend_weight = []
-        ascend_weight.append(
-            self.cast(parm_dict[embedding_weight_name], dtype))
+        ascend_weight.append(self.cast(parm_dict[embedding_weight_name], dtype))
         for i in range(self.num_layers):
-            ascend_weight.append(self._convert_tensor_format_and_dtype(
-                parm_dict[f"model.layers.{i}.{attention_norm_name}.weight"], dtype))
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(
+                    parm_dict[f"model.layers.{i}.{attention_norm_name}.weight"], dtype
+                )
+            )
             ascend_weight.extend([placeholder] * 3)
 
             ascend_weight.append(
-                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{qkv_name}.weight"], dtype))
-            ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
-                f"model.layers.{i}.{qkv_name}.bias", placeholder), dtype))
+                self._convert_tensor_format_and_dtype(
+                    parm_dict[f"model.layers.{i}.{qkv_name}.weight"], dtype
+                )
+            )
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(
+                    parm_dict.get(f"model.layers.{i}.{qkv_name}.bias", placeholder),
+                    dtype,
+                )
+            )
             ascend_weight.extend([placeholder] * 16)
 
             ascend_weight.append(
-                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{o_name}.weight"], dtype))
-            ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
-                f"model.layers.{i}.{o_name}.bias", placeholder), dtype))
+                self._convert_tensor_format_and_dtype(
+                    parm_dict[f"model.layers.{i}.{o_name}.weight"], dtype
+                )
+            )
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(
+                    parm_dict.get(f"model.layers.{i}.{o_name}.bias", placeholder), dtype
+                )
+            )
             ascend_weight.extend([placeholder] * 4)
 
             ascend_weight.append(
-                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{mlp_norm_name}.weight"], dtype))
+                self._convert_tensor_format_and_dtype(
+                    parm_dict[f"model.layers.{i}.{mlp_norm_name}.weight"], dtype
+                )
+            )
             ascend_weight.extend([placeholder] * 3)
 
             ascend_weight.append(
-                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{mlp_gate_name}.weight"], dtype))
-            ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
-                f"model.layers.{i}.{mlp_gate_name}.bias", placeholder), dtype))
+                self._convert_tensor_format_and_dtype(
+                    parm_dict[f"model.layers.{i}.{mlp_gate_name}.weight"], dtype
+                )
+            )
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(
+                    parm_dict.get(
+                        f"model.layers.{i}.{mlp_gate_name}.bias", placeholder
+                    ),
+                    dtype,
+                )
+            )
             ascend_weight.extend([placeholder] * 10)
 
             ascend_weight.append(
-                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{mlp_down_name}.weight"], dtype))
-            ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
-                f"model.layers.{i}.{mlp_down_name}.bias", placeholder), dtype))
+                self._convert_tensor_format_and_dtype(
+                    parm_dict[f"model.layers.{i}.{mlp_down_name}.weight"], dtype
+                )
+            )
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(
+                    parm_dict.get(
+                        f"model.layers.{i}.{mlp_down_name}.bias", placeholder
+                    ),
+                    dtype,
+                )
+            )
             ascend_weight.extend([placeholder] * 4)
 
         ascend_weight.append(
-            self._convert_tensor_format_and_dtype(parm_dict[f"{norm_out_name}.weight"], dtype))
+            self._convert_tensor_format_and_dtype(
+                parm_dict[f"{norm_out_name}.weight"], dtype
+            )
+        )
         ascend_weight.append(
-            self._convert_tensor_format_and_dtype(parm_dict[f"{lm_head_name}.weight"], dtype))
+            self._convert_tensor_format_and_dtype(
+                parm_dict[f"{lm_head_name}.weight"], dtype
+            )
+        )
+        self.ascend_weight = ascend_weight
         self.atb_encoder_operation.set_weights(ascend_weight)
         self.atb_decoder_operation.set_weights(ascend_weight)
 
@@ -147,20 +279,47 @@ class AtbBoostBase():
         """set kv_cache for llm boost"""
         if not k_caches or v_caches:
             if self.need_nz:
-                kv_shape = (self.config.num_blocks, self.num_kv_heads*self.head_dim //
-                            self.device_num // 16, self.config.block_size, 16)
-                k_caches = [_set_format(Parameter(Tensor(
-                    shape=kv_shape, dtype=self.dtype, init=Zero())), "FRACTAL_NZ") for _ in range(self.num_layers)]
-                v_caches = [_set_format(Parameter(Tensor(
-                    shape=kv_shape, dtype=self.dtype, init=Zero())), "FRACTAL_NZ") for _ in range(self.num_layers)]
+                kv_shape = (
+                    self.config.num_blocks,
+                    self.num_kv_heads * self.head_dim // self.device_num // 16,
+                    self.config.block_size,
+                    16,
+                )
+                k_caches = [
+                    _set_format(
+                        Parameter(
+                            Tensor(shape=kv_shape, dtype=self.dtype, init=Zero())
+                        ),
+                        FORMAT_NZ,
+                    )
+                    for _ in range(self.num_layers)
+                ]
+                v_caches = [
+                    _set_format(
+                        Parameter(
+                            Tensor(shape=kv_shape, dtype=self.dtype, init=Zero())
+                        ),
+                        FORMAT_NZ,
+                    )
+                    for _ in range(self.num_layers)
+                ]
             else:
-                kv_shape = (self.config.num_blocks, self.config.block_size,
-                            self.num_kv_heads // self.device_num, self.head_dim)
-                k_caches = [Parameter(Tensor(
-                    shape=kv_shape, dtype=self.dtype, init=Zero())) for _ in range(self.num_layers)]
-                v_caches = [Parameter(Tensor(
-                    shape=kv_shape, dtype=self.dtype, init=Zero())) for _ in range(self.num_layers)]
-
+                kv_shape = (
+                    self.config.num_blocks,
+                    self.config.block_size,
+                    self.num_kv_heads // self.device_num,
+                    self.head_dim,
+                )
+                k_caches = [
+                    Parameter(Tensor(shape=kv_shape, dtype=self.dtype, init=Zero()))
+                    for _ in range(self.num_layers)
+                ]
+                v_caches = [
+                    Parameter(Tensor(shape=kv_shape, dtype=self.dtype, init=Zero()))
+                    for _ in range(self.num_layers)
+                ]
+        self.k_caches = k_caches
+        self.v_caches = v_caches
         self.atb_encoder_operation.set_kvcache(k_caches, v_caches)
         self.atb_decoder_operation.set_kvcache(k_caches, v_caches)
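
Aside (not part of the diff): set_kvcache allocates per-layer paged caches whose shape depends on the device format — plain ND keeps a (blocks, block_size, heads, head_dim) layout, while FRACTAL_NZ tiles the heads*head_dim axis into 16-wide fragments, hence the extra //16 and the trailing 16. Plugging in hypothetical config values shows the two layouts hold the same element count:

# Hypothetical values standing in for config.num_blocks, config.block_size,
# num_kv_heads, head_dim and device_num.
num_blocks, block_size = 128, 16
num_kv_heads, head_dim, device_num = 8, 128, 2

nd_shape = (num_blocks, block_size, num_kv_heads // device_num, head_dim)
nz_shape = (num_blocks, num_kv_heads * head_dim // device_num // 16,
            block_size, 16)

print(nd_shape)  # (128, 16, 4, 128)
print(nz_shape)  # (128, 32, 16, 16) -- same element count, fractal tiling
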
 
@@ -171,11 +330,9 @@
     def _execute_operator(self, acl_inputs, acl_param):
         """execute operator."""
         if self.is_first_iteration:
-            acl_model_out = self.atb_encoder_operation.forward(
-                acl_inputs, acl_param)
+            acl_model_out = self.atb_encoder_operation.forward(acl_inputs, acl_param)
         else:
-            acl_model_out = self.atb_decoder_operation.forward(
-                acl_inputs, acl_param)
+            acl_model_out = self.atb_decoder_operation.forward(acl_inputs, acl_param)
         acl_hidden_state = acl_model_out[0]
         return acl_hidden_state
 
@@ -183,28 +340,46 @@ class AtbBoostBase():
         r"""
         LlmBoost forward.
         """
-        input_ids = boost_inputs["input_ids"]
-        position_ids = boost_inputs["position_ids"]
-        cos_embed = boost_inputs["cos_embed"]
-        sin_embed = boost_inputs["sin_embed"]
-        block_tables = boost_inputs["block_tables"]
-        slot_mapping = boost_inputs["slot_mapping"]
-        batch_valid_length = boost_inputs["batch_valid_length"]
-        lm_head_indices = boost_inputs["lm_head_indices"]
-        seqLen = boost_inputs["seq_lens"]
+        input_ids = boost_inputs.get("input_ids", None)
+        position_ids = boost_inputs.get("position_ids", None)
+        cos_embed = boost_inputs.get("cos_embed", None)
+        sin_embed = boost_inputs.get("sin_embed", None)
+        block_tables = boost_inputs.get("block_tables", None)
+        slot_mapping = boost_inputs.get("slot_mapping", None)
+        batch_valid_length = boost_inputs.get("batch_valid_length", None)
+        lm_head_indices = boost_inputs.get("lm_head_indices", None)
+        seqLen = boost_inputs.get("seq_lens", None)
+        input_ids = self.reshape(input_ids, (-1,))
         if self.is_first_iteration:
             attention_mask = self.attn_mask
         else:
-            position_ids = batch_valid_length - 1
+            if position_ids is None:
+                position_ids = batch_valid_length - 1
             attention_mask = self.placeholder
             lm_head_indices = self.lm_head_indices_fake
 
-        acl_inputs, acl_param = self._prepare_inputs(prefill=self.is_first_iteration, input_ids=input_ids,
-                                                     position_ids=position_ids, cos_embed=cos_embed,
-                                                     sin_embed=sin_embed, attention_mask=attention_mask,
-                                                     block_tables=block_tables, slots=slot_mapping,
-                                                     input_lengths=batch_valid_length, lm_head_indices=lm_head_indices,
-                                                     seqLen=seqLen)
+        if input_ids is not None and input_ids.dtype != mstype.int64:
+            input_ids = self.cast(input_ids, mstype.int64)
+        if position_ids is not None and position_ids.dtype != mstype.int64:
+            position_ids = self.cast(position_ids, mstype.int64)
+        if batch_valid_length is not None and batch_valid_length.dtype != mstype.int32:
+            batch_valid_length = self.cast(batch_valid_length, mstype.int32)
+        if lm_head_indices is not None and lm_head_indices.dtype != mstype.int64:
+            lm_head_indices = self.cast(lm_head_indices, mstype.int64)
+
+        acl_inputs, acl_param = self._prepare_inputs(
+            prefill=self.is_first_iteration,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cos_embed=cos_embed,
+            sin_embed=sin_embed,
+            attention_mask=attention_mask,
+            block_tables=block_tables,
+            slots=slot_mapping,
+            input_lengths=batch_valid_length,
+            lm_head_indices=lm_head_indices,
+            seqLen=seqLen,
+        )
         ms.hal.synchronize()
         logits = self._execute_operator(acl_inputs, acl_param)
         logits = self.cast(logits, mstype.float32)
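
Aside (not part of the diff): forward now reads boost_inputs defensively with dict.get, flattens input_ids, and normalizes integer dtypes itself — the casts that 2.4.1 performed inside _prepare_inputs, removed in the llama_boost.py hunks below. A hypothetical decode-step input dict using the key names forward reads; shapes and values are illustrative only:

import numpy as np
from mindspore import Tensor

# Single-token decode step for a batch of one. "position_ids" and
# "lm_head_indices" are omitted on purpose: in decode mode forward()
# derives position_ids from batch_valid_length - 1 and substitutes the
# fake lm_head_indices placeholder.
boost_inputs = {
    "input_ids": Tensor(np.array([[11]], dtype=np.int32)),        # recast to int64 inside
    "batch_valid_length": Tensor(np.array([8], dtype=np.int64)),  # recast to int32 inside
    "block_tables": Tensor(np.array([[0]], dtype=np.int32)),
    "slot_mapping": Tensor(np.array([7], dtype=np.int32)),
    "cos_embed": Tensor(np.zeros((1, 128), dtype=np.float16)),    # caller-built rotary tables
    "sin_embed": Tensor(np.zeros((1, 128), dtype=np.float16)),
    "seq_lens": [1],
}
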
mindspore/experimental/llm_boost/atb/llama_boost.py

@@ -15,10 +15,16 @@
 """llm boost"""
 import json
 import mindspore.common.dtype as mstype
-from mindspore.experimental.llm_boost.atb.boost_base import AtbBoostBase
+from mindspore.experimental.llm_boost.atb.boost_base import (
+    AtbBoostBase,
+    PositionEmbeddingType,
+    NormType,
+)
 from mindspore._c_expression import LlmBoostBinder
 from mindspore.experimental.llm_boost.register import LlmBoostRegister, LlmBoostType
 
+CPP_LLAMA_MODEL_CLASS_NAME = "llama_LlamaDecoderModel"
+
 
 @LlmBoostRegister.register(LlmBoostType.BUILDIN, "Llama")
 class LlamaBoost(AtbBoostBase):
@@ -30,14 +36,21 @@ class LlamaBoost(AtbBoostBase):
         self.acl_encoder_operation_inputs = [None] * self.in_tensor_length
         self.acl_decoder_operation_inputs = [None] * self.in_tensor_length
         self.atb_encoder_operation = LlmBoostBinder(
-            "ATB", "llama_parallel_DecoderModel")
+            self.backend_name, CPP_LLAMA_MODEL_CLASS_NAME
+        )
         self.atb_decoder_operation = LlmBoostBinder(
-            "ATB", "llama_parallel_DecoderModel")
+            self.backend_name, CPP_LLAMA_MODEL_CLASS_NAME
+        )
 
     def init(self):
-        """set param"""
+        """
+        Initialize the object
+        returns True if object needs input manipulation by mindformers
+        """
+
         coder_param = {
-            "rmsNormEps": self.config.rms_norm_eps,
+            "normEps": self.config.rms_norm_eps,
+            "normType": NormType.RMS_NORM,
             "numAttentionHeadsPerRank": self.config.num_heads // self.device_num,
             "hiddenSizePerAttentionHead": self.head_dim,
             "numHiddenLayers": self.num_layers,
@@ -46,35 +59,45 @@ class LlamaBoost(AtbBoostBase):
             "isFA": False,
             "isBF16": self.dtype == mstype.bfloat16,
             "packQuantType": [[1, 1] for _ in range(self.num_layers)],
-            "linearQuantType": [[0, -1, -1, 0, 0, -1, 0] for _ in range(self.num_layers)],
-            "linearTransposeType": [[1, -1, -1, 1, 1, -1, 1] for i in range(self.num_layers)],
+            "linearQuantType": [
+                [0, -1, -1, 0, 0, -1, 0] for _ in range(self.num_layers)
+            ],
+            "linearTransposeType": [
+                [1, -1, -1, 1, 1, -1, 1] for i in range(self.num_layers)
+            ],
             "isEmbeddingParallel": False,
             "isLmHeadParallel": not self.config.parallel_config.vocab_emb_dp,
             "lmHeadTransposeType": 1,
-            "supportSwiGLU": True,
-            "kvQuant": self.kv_quant is not None,
+            "enableSwiGLU": True,
+            "enablekvQuant": self.kv_quant is not None,
             "rank": self.rank_id,
             "worldSize": self.device_num,
-            "backend": "lccl",
+            "backend": self.config.communication_backend,
             "rankTableFile": "",
-            "positionEmbeddingType": self.position_embedding_type,
+            "positionEmbeddingType": PositionEmbeddingType.ROPE,
             "hiddenSize": self.config.hidden_size,
             "gemma": False,
-            "enableAddNorm": True,
-            "supportCompressHead": False,
+            "enableAddNorm": False,
+            "enableCompressHead": False,
+            "isUnpadInputs": True,
         }
         encoder_param = {
-            **coder_param, "isPrefill": True,
-            "supportLcoc": True,
-            "supportSpeculate": False,
-            "skipWordEmbedding": False
+            **coder_param,
+            "isPrefill": True,
+            "enableLcoc": True,
+            "enableSpeculate": False,
+            "skipWordEmbedding": False,
+            "enableSplitFuse": False,
         }
         decoder_param = {
-            **coder_param, "isPrefill": False, "supportLcoc": False,
-            "supportSpeculate": False
+            **coder_param,
+            "isPrefill": False,
+            "enableLcoc": False,
+            "enableSpeculate": False,
         }
         self.atb_encoder_operation.init(json.dumps({**encoder_param}))
         self.atb_decoder_operation.init(json.dumps({**decoder_param}))
+        return True
 
     def _prepare_inputs(
         self,
@@ -92,14 +115,15 @@ class LlamaBoost(AtbBoostBase):
         **kwargs
     ):
         """prepare inputs"""
-        self.acl_param = json.dumps({
-            "seqLen": seqLen,
-        })
-        self.acl_decoder_operation_inputs[0] = self.cast(
-            input_ids, mstype.int64)
+        self.acl_param = json.dumps(
+            {
+                "seqLen": seqLen,
+            }
+        )
+
+        self.acl_decoder_operation_inputs[0] = input_ids
         self.acl_decoder_operation_inputs[1] = self.placeholder
-        self.acl_decoder_operation_inputs[2] = self.cast(
-            position_ids, mstype.int32)
+        self.acl_decoder_operation_inputs[2] = position_ids
         self.acl_decoder_operation_inputs[3] = cos_embed
         self.acl_decoder_operation_inputs[4] = sin_embed
         self.acl_decoder_operation_inputs[5] = attention_mask
@@ -108,8 +132,6 @@ class LlamaBoost(AtbBoostBase):
         self.acl_decoder_operation_inputs[8] = self.placeholder
         self.acl_decoder_operation_inputs[9] = self.placeholder
         self.acl_decoder_operation_inputs[10] = self.placeholder
-        self.acl_decoder_operation_inputs[11] = self.cast(
-            input_lengths, mstype.int32)
-        self.acl_decoder_operation_inputs[12] = self.cast(
-            lm_head_indices, mstype.int64)
+        self.acl_decoder_operation_inputs[11] = input_lengths
+        self.acl_decoder_operation_inputs[12] = lm_head_indices
         return self.acl_decoder_operation_inputs, self.acl_param
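
Aside (not part of the diff): back in LlamaBoost.init above, one shared coder_param dict is built and the prefill/decode variants are derived by dict unpacking, where later keys extend or override the copy; the extra {**encoder_param} inside json.dumps makes one more shallow copy and is harmless. A minimal sketch of the merge semantics, with hypothetical keys:

import json

coder_param = {"normType": 0, "enableLcoc": False}

# {**base, ...} copies base, then the phase-specific keys win.
encoder_param = {**coder_param, "isPrefill": True, "enableLcoc": True}
decoder_param = {**coder_param, "isPrefill": False}

print(json.dumps(encoder_param))
# {"normType": 0, "enableLcoc": true, "isPrefill": true}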