mindspore-2.7.0rc1-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (370)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/_extends/remote/kernel_build_server_ascend.py
@@ -1,75 +0,0 @@
-# Copyright 2020-2021 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""kernel build server for ascend"""
-import sys
-import warnings
-import json
-
-from mindspore._extends.parallel_compile.tbe_compiler.tbe_job_manager import TbeJobManager
-from mindspore._extends.remote.kernel_build_server import Messager, get_logger, AkgBuilder
-
-
-class AscendMessager(Messager):
-    """
-    Ascend Messager
-    It works as a server, communicating with c++ client.
-    """
-
-    def __init__(self, fdin, fdout):
-        super().__init__(fdin, fdout)
-        get_logger().info("[TRACE] Ascend Messager init...")
-        self.tbe_builder = TbeJobManager()
-        self.akg_builder = AkgBuilder("ASCEND")
-
-    def handle(self):
-        """
-        Communicate with remote client.
-        Reference protocol between them at PR#3821 and PR#3935
-        """
-        arg = self.get_message()
-        if arg.startswith('AKG'):
-            self.akg_builder.handle(self, arg)
-        else:
-            job_json = dict()
-            try:
-                job_json = json.loads(arg)
-            except json.decoder.JSONDecodeError:
-                get_logger().error("[TRACE] Request is not a json message: {}".format(arg))
-                self.send_ack(False)
-                self.exit()
-            finally:
-                pass
-
-            if "job_type" in job_json:
-                res = self.tbe_builder.job_handler(arg)
-                self.send_res(res)
-            else:
-                get_logger().error("[TRACE] Request is not a TBE Job message: {}".format(arg))
-                self.send_ack(False)
-                self.exit()
-
-    def exit(self):
-        self.tbe_builder.reset()
-        get_logger().info("[TRACE] Ascend Messager Exit...")
-        exit()
-
-
-if __name__ == '__main__':
-    warnings.simplefilter("ignore")
-    if len(sys.argv) != 3:
-        raise Exception('Incorrect argv: {}'.format(sys.argv))
-    get_logger().debug(f"[TRACE] argv: {str(sys.argv)}")
-    messager = AscendMessager(int(sys.argv[1]), int(sys.argv[2]))
-    messager.run()
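The retired build server spoke a minimal text protocol over a file-descriptor pair: a message prefixed with 'AKG' was routed to the AKG builder, any other message had to parse as JSON and carry a "job_type" key to reach the TBE job manager, and everything else was NACKed before the server exited. A sketch of the two accepted message shapes, where every field except the "job_type" key is a hypothetical placeholder:

    import json

    # JSON path: the removed handler only verified that the payload parses
    # as JSON and contains "job_type" before forwarding it to TbeJobManager.
    tbe_msg = json.dumps({"job_type": "Compile"})  # "Compile" is a placeholder value

    # AKG path: any message with the literal 'AKG' prefix skipped JSON parsing
    # and was handed to AkgBuilder.handle() as-is (payload format not shown here).
    akg_msg = "AKG" + " <builder-specific payload>"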
mindspore/communication/_hccl_management.py
@@ -1,297 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""HCCL management API"""
-from __future__ import absolute_import
-from __future__ import division
-
-import ctypes
-import os
-
-from mindspore import context
-from mindspore._c_expression import get_hccl_rank_id, get_hccl_rank_size
-
-MAX_GROUP_NAME_LEN = 127
-MAX_RANK_NUM = 4096
-HCCL_LIB = 'libhccl_plugin.so'
-HCCL_LIB_CTYPES = ""
-
-
-def check_group(group):
-    """
-    A function that check if a collection communication group is legal.
-
-    Returns:
-        None
-    """
-    if isinstance(group, (str)):
-        group_len = len(group)
-        if group_len > MAX_GROUP_NAME_LEN or group_len == 0:
-            raise ValueError("The length of communication group name must be in range [1, 127), "
-                             "but got the value : {} ".format(group_len))
-    else:
-        raise TypeError("The type of communication group name must be type of string, "
-                        "but got 'group' type : {}.".format(type(group)))
-
-
-def check_rank_num(rank_num):
-    """
-    A function that check if a collection communication rank number is legal.If not raise error.
-
-    Returns:
-        None
-    """
-    if isinstance(rank_num, (int)):
-        if rank_num > MAX_RANK_NUM or rank_num <= 0:
-            raise ValueError("For 'create_group', the size of argument 'rand_ids' should be greater than 0 and"
-                             "less than {}, but got the size of 'rank_ids' : {}.".format(MAX_RANK_NUM, rank_num))
-    else:
-        raise TypeError("The argument 'rank_num' must be type of int, "
-                        "but got 'rank_num' type : {}.".format(type(rank_num)))
-
-
-def check_rank_id(rank_id):
-    """
-    A function that check if a collection communication rank id is legal.If not raise error.
-
-    Returns:
-        None
-    """
-    if isinstance(rank_id, (int)):
-        if rank_id >= MAX_RANK_NUM or rank_id < 0:
-            raise ValueError("The rand id in the communication group must be greater or equal 0 and "
-                             "less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id))
-    else:
-        raise TypeError("The rand id in the communication group must be must be type of int, "
-                        "but got type value : {}.".format(type(rank_id)))
-
-
-def load_lib():
-    """load hccl lib"""
-    try:
-        base_dir = os.path.dirname(os.path.realpath(__file__))
-        lib_path = os.path.join(base_dir, "../lib/plugin/ascend", HCCL_LIB)
-        hccl_lib = ctypes.CDLL(lib_path)
-    except Exception:
-        raise RuntimeError('Get hccl lib error.')
-
-    global HCCL_LIB_CTYPES
-    HCCL_LIB_CTYPES = hccl_lib
-
-
-def c_str(string):
-    """Convert a python string to C string."""
-    if not isinstance(string, str):
-        string = string.decode('ascii')
-    return ctypes.c_char_p(string.encode('utf-8'))
-
-
-def c_array(ctype, values):
-    """Create ctypes array from a python array."""
-    return (ctype * len(values))(*values)
-
-
-def create_group(group, rank_num, rank_ids):
-    """
-    Create group.
-
-    A function that creates a collection communication group which includes 'rank_num'
-    device and 'rank_ids' is the list of these ranks of devices.
-
-    Note:
-        The world group can not be created.
-
-    Returns:
-        None
-    """
-    check_group(group)
-    check_rank_num(rank_num)
-    if isinstance(rank_ids, (list)):
-        if rank_num != len(rank_ids):
-            raise ValueError("The argument 'rank_num' number should be equal to the length "
-                             "of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}."
-                             .format(rank_num, rank_ids))
-        for rank_id in rank_ids:
-            if not isinstance(rank_id, (int)) or rank_id < 0:
-                raise ValueError("The elements of argument 'rank_ids' must be "
-                                 "unsigned integer, but got the type : {}".format(type(rank_id)))
-        c_array_rank_ids = c_array(ctypes.c_uint, rank_ids)
-        c_rank_num = ctypes.c_uint(rank_num)
-        c_group = c_str(group)
-        ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
-        if ret != 0:
-            raise RuntimeError('Create group error, the error code is {}.'.format(ret))
-    else:
-        raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
-                        "but got 'rank_ids' type : {}.".format(type(rank_ids)))
-
-
-def destroy_group(group):
-    """
-    A function that destroy the group which created by user.
-
-    Note:
-        The world group can not be destroy.
-
-    Returns:
-        None
-    """
-    check_group(group)
-    c_group = c_str(group)
-    ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group)
-    if ret != 0:
-        raise RuntimeError('Destroy group error.')
-
-
-def get_rank_size(group="hccl_world_group"):
-    """
-    A function that returns the number of ranks within the given collection communication group.
-
-    Note:
-        The default group is hccl_world_group.
-
-    Returns:
-        An integer scalar with the num of ranks.
-    """
-
-    if context.get_context("mode") == context.PYNATIVE_MODE:
-        return get_hccl_rank_size()
-
-    check_group(group)
-    c_group = c_str(group)
-    c_rank_size = ctypes.c_uint()
-    ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size))
-    if ret != 0:
-        raise RuntimeError('Get rank size error.')
-
-    return c_rank_size.value
-
-
-def get_rank_id(group="hccl_world_group"):
-    """
-    A function that returns the rank id of the calling process, within the given collection communication group.
-
-    Returns:
-        An integer scalar with the rank id of the calling process.
-    """
-
-    if context.get_context("mode") == context.PYNATIVE_MODE:
-        return get_hccl_rank_id()
-
-    check_group(group)
-    c_group = c_str(group)
-    c_rank_id = ctypes.c_uint()
-    ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id))
-    if ret != 0:
-        raise RuntimeError('Get rank id error.')
-
-    return c_rank_id.value
-
-
-
-def get_local_rank_size(group="hccl_world_group"):
-    """
-    A function that returns the number of local ranks within the given collection communication group.
-
-    Note:
-        The default group is hccl_world_group.
-
-    Returns:
-        An integer scalar with the num of local ranks.
-    """
-    if context.get_context("mode") is context.PYNATIVE_MODE:
-        raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, "
-                           "'get_local_rank_size' only support GRAPH_MODE")
-    check_group(group)
-    c_group = c_str(group)
-    c_local_rank_size = ctypes.c_uint()
-    ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size))
-    if ret != 0:
-        raise RuntimeError('Get local rank size error.')
-
-    return c_local_rank_size.value
-
-
-def get_local_rank_id(group="hccl_world_group"):
-    """
-    Get local rank id.
-
-    A function that returns the local rank id of the calling process, within the given collection communication group.
-
-    Returns:
-        An integer scalar with the local rank id of the calling process.
-    """
-
-    if context.get_context("mode") is context.PYNATIVE_MODE:
-        raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, "
-                           "'get_local_rank_id' only support GRAPH_MODE")
-    check_group(group)
-    c_group = c_str(group)
-    c_local_rank_id = ctypes.c_uint()
-    ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id))
-    if ret != 0:
-        raise RuntimeError('Get local rank id error.')
-
-    return c_local_rank_id.value
-
-
-def get_world_rank_from_group_rank(group, group_rank_id):
-    """
-    Get world rank from group rank.
-
-    A function that returns the rank id in the world group corresponding to the
-    rank which id is 'group_rank_id' in the user group.
-
-    Returns:
-        An integer scalar with the rank id in the world group.
-    """
-    if context.get_context("mode") is context.PYNATIVE_MODE:
-        raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, "
-                           "'get_world_rank_from_group_rank' only support GRAPH_MODE")
-    check_group(group)
-    check_rank_id(group_rank_id)
-    c_group = c_str(group)
-    c_group_rank_id = ctypes.c_uint(group_rank_id)
-    c_world_rank_id = ctypes.c_uint()
-    ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id))
-    if ret != 0:
-        raise RuntimeError('Get world rank from group rank error.')
-
-    return c_world_rank_id.value
-
-
-def get_group_rank_from_world_rank(world_rank_id, group):
-    """
-    Get group rank from world rank.
-
-    A function that returns the rank id in the user group corresponding to the
-    rank which id is 'world_rank_id' in the world group.
-
-    Returns:
-        An integer scalar with the rank id in the user group.
-    """
-    if context.get_context("mode") is context.PYNATIVE_MODE:
-        raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, "
-                           "'get_group_rank_from_world_rank' only support GRAPH_MODE")
-    check_group(group)
-    check_rank_id(world_rank_id)
-    c_group = c_str(group)
-    c_world_rank_id = ctypes.c_uint(world_rank_id)
-    c_group_rank_id = ctypes.c_uint()
-    ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id))
-    if ret != 0:
-        raise RuntimeError('Get group rank from world rank error.')
-
-    return c_group_rank_id.value
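The removed shim loaded libhccl_plugin.so via ctypes and called the Hcom* entry points directly; with it gone, group bookkeeping goes through the public mindspore.communication package (extended by 85 lines in this release, per the file list above). A minimal migration sketch, assuming the long-standing public wrappers keep their documented signatures:

    # Sketch of the public replacements for the removed ctypes helpers.
    from mindspore.communication import (
        init, create_group, get_rank, get_group_size,
        get_world_rank_from_group_rank, get_group_rank_from_world_rank,
        destroy_group,
    )

    init()                                    # sets up HCCL; replaces the manual load_lib()
    create_group("subgroup", [0, 1, 2, 3])    # replaces the raw HcomCreateGroup call
    world_rank = get_world_rank_from_group_rank("subgroup", 1)
    rank, size = get_rank("subgroup"), get_group_size("subgroup")
    destroy_group("subgroup")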
mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py
@@ -1,207 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""AscendNative Llama Boost APIs."""
-
-import os
-import numpy as np
-from mindspore.common import Tensor, dtype
-from mindspore.experimental.llm_boost.ascend_native.llm_boost import LLMBoost
-
-def RoundUp(val: int, align: int) -> int:
-    if align == 0:
-        return 0
-    return -(val // -align) * align
-
-
-def ConvertTensor(nd_mat: np.ndarray, transpose: bool = True, nd2nz: bool = True) -> np.ndarray:
-    """ Transforms tensor format from Nd to Nz """
-    if transpose:
-        nd_mat = np.transpose(nd_mat)
-    if not nd2nz:
-        return nd_mat
-    block_size = (16, 16)
-    r = RoundUp(nd_mat.shape[0], block_size[0])
-    c = RoundUp(nd_mat.shape[1], block_size[1])
-    r_pad = r - nd_mat.shape[0]
-    c_pad = c - nd_mat.shape[1]
-    nd_mat = np.pad(nd_mat, ((0, r_pad), (0, c_pad)))
-    nz_mat = np.transpose(np.reshape(
-        nd_mat, (r, c // block_size[1], block_size[1])), (1, 0, 2))
-    nz_mat = nz_mat.reshape(r, c)
-    return nz_mat
-
-class LlamaBoostAscendNative(LLMBoost):
-    r"""
-    Implements an Llama model in a single kernel.
-    it forwards the python functions to the C++ binded object
-    """
-    def _get_from_dict(self, dictionary, name):
-        """ internal function to get a specific tensor from the dictionary """
-        all_relevant_layers = [value for key, value in dictionary.items() if name in key]
-        if all_relevant_layers:
-            return all_relevant_layers[0].asnumpy()
-        return None
-
-    def _get_quant_triplet_from_dict(self, dictionary, name):
-        """ internal function to get a weight triple tensor from the dictionary """
-        weights = self._get_from_dict(dictionary, name + "._handler.weight")
-        scale = self._get_from_dict(dictionary, name + "._weight_quantizer.scale")
-        offset = self._get_from_dict(dictionary, name + "._weight_quantizer.zp_neg")
-        return weights, scale, offset
-
-    def _prepare_single_layer(self, ckpt, config, id):
-        """ prepares the dictionary of weights of a single layer """
-        prefix = 'model.layers.' + str(id)
-        is_last = id == config.num_layers-1
-        layer = 'layers.' + str(id) + '.'
-        l_dict = {key: value for key, value in ckpt.items() if layer in key}
-        if config.n_kv_heads is None:
-            config.n_kv_heads = config.num_heads
-        start = 0
-        end = config.hidden_size
-        kv_start = 0
-        kv_end = int(config.hidden_size*config.n_kv_heads/config.num_heads)
-        ffn_hid = [value for key, value in l_dict.items() if "w3" in key][0].shape[0]
-        ffn_start = 0
-        ffn_end = ffn_hid
-        rank_size = int(os.getenv('RANK_SIZE', '1'))
-        #Emir if (config.parallel_mode != 2): # 2 - AUTO_PARALLEL
-        hid_size = end
-        kv_hid_size = kv_end
-        embed_size = config.vocab_size
-        rank_id = int(os.getenv('RANK_ID', '0'))
-        if (hid_size % rank_size == 0) and (ffn_hid % rank_size == 0) and (embed_size % rank_size == 0):
-            start = int(rank_id * hid_size / rank_size)
-            end = int((rank_id + 1) * hid_size / rank_size)
-            kv_start = int(rank_id * kv_hid_size / rank_size)
-            kv_end = int((rank_id + 1) * kv_hid_size / rank_size)
-            ffn_start = int(rank_id * ffn_hid / rank_size)
-            ffn_end = int((rank_id + 1) * ffn_hid / rank_size)
-        else:
-            raise RuntimeError("hidden size and ffn hidden size must be divided by rank size without remainder. \
-                hidden_size: ", hid_size, " ffn_hidden_size: ", ffn_hid, " rank_size: ", rank_size)
-        quant = self._get_from_dict(l_dict, "_weight_quantizer") is not None
-        unite_qkv = config.num_heads == config.n_kv_heads
-        self.dictionary[prefix + ".attention_norm.weight"] = \
-            Tensor(self._get_from_dict(l_dict, "attention_norm"), dtype=dtype.float16)
-        self.dictionary[prefix + ".ffn_norm.weight"] = \
-            Tensor(self._get_from_dict(l_dict, "ffn_norm"), dtype=dtype.float16)
-        if is_last:
-            self.dictionary['lm_head.weight'] = Tensor(ConvertTensor(ckpt['lm_head.weight'].asnumpy()[:, start:end]))
-
-        if not quant:
-            self._pack_attn_weights(l_dict, prefix, start, end, kv_start, kv_end, unite_qkv)
-            self._pack_ffn_weights(l_dict, prefix, ffn_start, ffn_end)
-        else:
-            self._pack_attn_quant_weights(l_dict, prefix, start, end, kv_start, kv_end, unite_qkv)
-            self._pack_ffn_quant_weights(l_dict, prefix, ffn_start, ffn_end)
-
-    def _pack_attn_weights(self, l_dict, prefix, start, end, kv_start, kv_end, unite_qkv):
-        """ prepares the dictionary of weights of an attention block """
-        wq = self._get_from_dict(l_dict, "wq")[start:end, :]
-        wk = self._get_from_dict(l_dict, "wk")[kv_start:kv_end, :]
-        wv = self._get_from_dict(l_dict, "wv")[kv_start:kv_end, :]
-        self.dictionary[prefix + ".attention.wo.weight"] = \
-            Tensor(ConvertTensor(self._get_from_dict(l_dict, "wo")[:, start:end]))
-        if unite_qkv:
-            self.dictionary[prefix + ".attention.wqkv.weight"] = Tensor(ConvertTensor(np.concatenate((wq, wk, wv))))
-        else:
-            self.dictionary[prefix + ".attention.wq.weight"] = Tensor(ConvertTensor(wq))
-            self.dictionary[prefix + ".attention.wkv.weight"] = Tensor(ConvertTensor(np.concatenate((wk, wv))))
-
-    def _pack_ffn_weights(self, l_dict, prefix, ffn_start, ffn_end):
-        """ prepares the dictionary of weights of an ffn block """
-        self.dictionary[prefix + ".feed_forward.w2.weight"] = \
-            Tensor(ConvertTensor(self._get_from_dict(l_dict, "w2")[:, ffn_start:ffn_end]))
-        w1 = self._get_from_dict(l_dict, "w1")[ffn_start:ffn_end, :]
-        w3 = self._get_from_dict(l_dict, "w3")[ffn_start:ffn_end, :]
-        self.dictionary[prefix + ".feed_forward.w13.weight"] = Tensor(ConvertTensor(np.concatenate((w1, w3))))
-
-    def _pack_attn_quant_weights(self, l_dict, prefix, start, end, kv_start, kv_end, unite_qkv):
-        """ prepares the dictionary of weights of a quantized attention block """
-        wq, wq_scale, wq_offset = self._get_quant_triplet_from_dict(l_dict, "wq")
-        wk, wk_scale, wk_offset = self._get_quant_triplet_from_dict(l_dict, "wk")
-        wv, wv_scale, wv_offset = self._get_quant_triplet_from_dict(l_dict, "wv")
-        wo, wo_scale, wo_offset = self._get_quant_triplet_from_dict(l_dict, "wo")
-        self.dictionary[prefix + ".attention.wo.weight"] = Tensor(ConvertTensor(wo[:, start:end], nd2nz=False))
-        self.dictionary[prefix + ".attention.wo.weight.scale"] = Tensor(wo_scale[start:end])
-        self.dictionary[prefix + ".attention.wo.weight.offset"] = Tensor(wo_offset[start:end])
-
-        if unite_qkv:
-            self.dictionary[prefix + ".attention.wqkv.weight"] = \
-                Tensor(ConvertTensor(np.concatenate((wq[start:end, :], wk[kv_start:kv_end, :], wv[kv_start:kv_end, :])),
-                       nd2nz=False))
-            self.dictionary[prefix + ".attention.wqkv.weight.scale"] = \
-                Tensor(np.concatenate((wq_scale[start:end], wk_scale[kv_start:kv_end], wv_scale[kv_start:kv_end])))
-            self.dictionary[prefix + ".attention.wqkv.weight.offset"] = \
-                Tensor(np.concatenate((wq_offset[start:end], wk_offset[kv_start:kv_end], wv_offset[kv_start:kv_end])))
-        else:
-            self.dictionary[prefix + ".attention.wq.weight"] = Tensor(ConvertTensor(wq[start:end, :], nd2nz=False))
-            self.dictionary[prefix + ".attention.wq.weight.scale"] = Tensor(wq_scale[start:end])
-            self.dictionary[prefix + ".attention.wq.weight.offset"] = Tensor(wq_offset[start:end])
-            self.dictionary[prefix + ".attention.wkv.weight"] = \
-                Tensor(ConvertTensor(np.concatenate((wk[kv_start:kv_end, :], wv[kv_start:kv_end, :])), nd2nz=False))
-            self.dictionary[prefix + ".attention.wkv.weight.scale"] = \
-                Tensor(np.concatenate((wk_scale[kv_start:kv_end], wv_scale[kv_start:kv_end])))
-            self.dictionary[prefix + ".attention.wkv.weight.offset"] = \
-                Tensor(np.concatenate((wk_offset[kv_start:kv_end], wv_offset[kv_start:kv_end])))
-
-    def _pack_ffn_quant_weights(self, l_dict, prefix, ffn_start, ffn_end):
-        """ prepares the dictionary of weights of a quantized ffn block """
-        w1, w1_scale, w1_offset = self._get_quant_triplet_from_dict(l_dict, "w1")
-        w2, w2_scale, w2_offset = self._get_quant_triplet_from_dict(l_dict, "w2")
-        w3, w3_scale, w3_offset = self._get_quant_triplet_from_dict(l_dict, "w3")
-        self.dictionary[prefix + ".feed_forward.w2.weight"] = Tensor(ConvertTensor(w2[:, ffn_start:ffn_end],
-                                                                                   nd2nz=False))
-        self.dictionary[prefix + ".feed_forward.w2.weight.scale"] = Tensor(w2_scale[ffn_start:ffn_end])
-        self.dictionary[prefix + ".feed_forward.w2.weight.offset"] = Tensor(w2_offset[ffn_start:ffn_end])
-
-        self.dictionary[prefix + ".feed_forward.w13.weight"] = \
-            Tensor(ConvertTensor(np.concatenate((w1[ffn_start:ffn_end, :], w3[ffn_start:ffn_end, :])), nd2nz=False))
-        self.dictionary[prefix + ".feed_forward.w13.weight.scale"] = \
-            Tensor(np.concatenate((w1_scale[ffn_start:ffn_end], w3_scale[ffn_start:ffn_end])))
-        self.dictionary[prefix + ".feed_forward.w13.weight.offset"] = \
-            Tensor(np.concatenate((w1_offset[ffn_start:ffn_end], w3_offset[ffn_start:ffn_end])))
-
-    def _prepare_cos_sin_arrays(self, config, theta=10000):
-        """ prepares the cosine and sine arrays """
-        head_dim = config.hidden_size // config.num_heads
-        max_position_embedding = \
-            config.max_position_embedding if config.max_position_embedding is not None else config.seq_length
-        freqs_base = np.arange(0, head_dim, 2)[: (head_dim // 2)].astype(np.float32)
-        freqs = 1.0 / (theta ** (freqs_base / head_dim))
-        t = np.arange(0, max_position_embedding, 1).astype(np.float32)
-        freqs = np.outer(t, freqs)
-        emb = np.concatenate((freqs, freqs), axis=-1)
-        freqs_cos = Tensor(np.cos(emb), dtype=dtype.float16)
-        sin = np.sin(emb)
-
-        sin[:, :int(emb.shape[1]/2)] = -sin[:, :int(emb.shape[1]/2)]
-        self.dictionary['model.cos.weight'] = freqs_cos
-        freqs_sin = Tensor(sin, dtype=dtype.float16)
-        self.dictionary['model.sin.weight'] = freqs_sin
-
-    def set_weights(self, ckpt_dict):
-        """ load the checkpoint """
-        self.dictionary = {}
-        self.dictionary['model.tok_embeddings.embedding_weight'] = \
-            Tensor(ckpt_dict['model.tok_embeddings.embedding_weight'].asnumpy())
-        self.dictionary['model.norm_out.weight'] = \
-            Tensor(ckpt_dict['model.norm_out.weight'].asnumpy(), dtype=dtype.float16)
-        self._prepare_cos_sin_arrays(self.config)
-        for layer_id in range(self.config.num_layers):
-            self._prepare_single_layer(ckpt_dict, self.config, layer_id)
-
-        self.binder.set_weights_map(self.dictionary)
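The ConvertTensor helper in the removed file is its only non-trivial numeric step: it transposes a 2-D weight, zero-pads both dimensions up to multiples of 16, and regroups the columns into 16-wide blocks (the Ascend "Nz" fractal layout). A standalone numpy rendering of that shape arithmetic, written independently of the deleted module so it runs on its own:

    import numpy as np

    def round_up(val: int, align: int) -> int:
        # ceil(val / align) * align, mirroring the removed RoundUp helper
        return 0 if align == 0 else -(val // -align) * align

    w = np.arange(20 * 50, dtype=np.float16).reshape(20, 50)
    w = w.T                                        # transpose=True path: now (50, 20)
    r = round_up(w.shape[0], 16)                   # 50 -> 64
    c = round_up(w.shape[1], 16)                   # 20 -> 32
    w = np.pad(w, ((0, r - w.shape[0]), (0, c - w.shape[1])))  # zero-pad to (64, 32)
    # split columns into 16-wide blocks, stack the blocks, flatten back to 2-D
    nz = np.transpose(w.reshape(r, c // 16, 16), (1, 0, 2)).reshape(r, c)
    assert nz.shape == (64, 32)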
mindspore/experimental/llm_boost/ascend_native/llm_boost.py
@@ -1,52 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""LLMBoost APIs."""
-
-from mindspore.common import Tensor
-
-class LLMBoost():
-    r"""
-    Implements an LLM in a single kernel.
-    it forwards the python function to the C++ binded object
-    """
-    def __init__(self, config):
-        r"""
-        initialize the parameters of the llm binder.
-        config is simply the config object of the model
-        """
-        from mindspore._c_expression import LlmBoostBinder
-        self.config = config
-        self.binder = LlmBoostBinder("AscendNative", config.model_type)
-        self.binder.init_model(config.to_dict())
-
-    def init(self):
-        """
-        Initialize the object
-        returns True if object needs input manipulation by mindformers
-        """
-        return False
-
-    def set_kvcache(self, k_caches=None, v_caches=None):
-        return
-
-    def forward(self, input_ids, batch_valid_length, position_ids=None):
-        ret = self.binder.forward([input_ids, batch_valid_length], "nothing really")
-        return Tensor(ret[0])
-
-    def set_weights(self, ckpt_dict):
-        self.binder.set_weights_map(ckpt_dict)
-
-    def add_flags(self, is_first_iteration=False):
-        self.binder.add_flags(is_first_iteration=is_first_iteration)
mindspore/experimental/llm_boost/atb/__init__.py
@@ -1,23 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-Provide llm boost for inference, such as LlamaBoost.
-"""
-from __future__ import absolute_import
-
-from mindspore.experimental.llm_boost.atb.llama_boost import LlamaBoost
-from mindspore.experimental.llm_boost.atb.qwen_boost import QwenBoost
-
-__all__ = ['LlamaBoost', 'QwenBoost']