mindspore 2.7.0rc1__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (370) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,197 @@
1
+ # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
2
+ #
3
+ # Copyright 2025 Huawei Technologies Co., Ltd
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ============================================================================
17
+ """Define enable_dynamic decorator."""
18
+ import types
19
+ import inspect
20
+ from mindspore import log as logger
21
+ from mindspore.common.tensor import Tensor
22
+ from mindspore.common._utils import get_func, is_dim_unknown
23
+ from mindspore.common.dynamic_shape.auto_dynamic_shape import SHAPE_DIM_ANY
24
+
25
+
26
+ ENABLE_DYNAMIC = "__enable_dynamic__"
27
+
28
+
29
def _check_element_valid(item, shape, name):
    """
    Validate a single dimension entry of a declared shape.

    Args:
        item: One entry of `shape`.
        shape: The full shape `item` belongs to; only used in the error message.
        name (str): Name of the decorator argument; only used in the error message.

    Returns:
        bool, always True when no error is raised.

    Raises:
        TypeError: If `item` is an int that is zero or negative.
    """
    # The dynamic-dimension marker is always accepted.
    if item is SHAPE_DIM_ANY:
        return True
    # Only non-positive ints are rejected; entries of any other type pass through unchanged.
    if isinstance(item, int) and item <= 0:
        raise TypeError(f"The argument '{name}' has invalid shape '{shape}', only supports None "
                        f"or a tuple/list of positive integers and None.")
    return True
35
+
36
+
37
def _check_arg_shape_valid(arg, name):
    """
    Tell whether a decorator argument declares something dynamic.

    Args:
        arg: A Tensor, or a (possibly nested) tuple/list of Tensors.
        name (str): Name of the decorator argument, forwarded to the element check.

    Returns:
        bool, True for a Tensor whose shape is reported unknown by `is_dim_unknown`,
        for a Tensor whose shape contains `SHAPE_DIM_ANY` (after every entry passes
        `_check_element_valid`), or for a tuple/list with at least one such element.
        False otherwise.

    Raises:
        TypeError: Propagated from `_check_element_valid` for an invalid dimension.
    """
    if isinstance(arg, Tensor):
        declared_shape = arg.shape
        if is_dim_unknown(declared_shape):
            return True
        if SHAPE_DIM_ANY not in declared_shape:
            return False
        # Each entry either validates or raises, so reaching the end means success.
        for dim in declared_shape:
            _check_element_valid(dim, declared_shape, name)
        return True
    if isinstance(arg, (tuple, list)):
        # One dynamic member is enough for a sequence to qualify.
        return any(_check_arg_shape_valid(member, name) for member in arg)
    return False
49
+
50
+
51
def _check_arg_type_valid(arg, name):
    """
    Ensure a decorator argument is a Tensor or a nested tuple/list of Tensors.

    Args:
        arg: The value given to the decorator for one parameter.
        name (str): Name of that parameter; only used in the error message.

    Raises:
        TypeError: If `arg`, or any element reached while descending into
            tuples/lists, is neither a Tensor nor a tuple/list.
    """
    if isinstance(arg, Tensor):
        return
    if not isinstance(arg, (tuple, list)):
        raise TypeError(f"The decorator enable_dynamic only supports Tensor "
                        f"or a tuple/list of Tensor, but the argument : {name} is type of:{type(arg)}.")
    for element in arg:
        _check_arg_type_valid(element, name)
61
+
62
+
63
def _check_input_valid(arg):
    """
    Validate a real (call-time) argument that has a dynamic declaration.

    Args:
        arg: A Tensor, or a (possibly nested) tuple/list of Tensors.

    Raises:
        ValueError: If a Tensor's shape holds an entry that is not a positive int.
        TypeError: If `arg` or a nested element is neither a Tensor nor a tuple/list.
    """
    if isinstance(arg, Tensor):
        has_bad_dim = any(not isinstance(dim, int) or dim <= 0 for dim in arg.shape)
        if has_bad_dim:
            raise ValueError(f"When using decorator enable_dynamic, the corresponding shape of inputs should be "
                             f"a tuple/list of positive integers")
        return
    if isinstance(arg, (tuple, list)):
        for element in arg:
            _check_input_valid(element)
        return
    raise TypeError(f"When using decorator enable_dynamic, the corresponding inputs only supports Tensor "
                    f"or a tuple/list of Tensor.")
75
+
76
+
77
def _check_arg_type_shape(arg, dyn_arg, name):
    """
    Check that a real argument is compatible with its dynamic declaration.

    Args:
        arg: The real argument (Tensor or nested tuple/list of Tensors).
        dyn_arg: The declaration given to the decorator for the same parameter.
        name (str): Name of the parameter; only used in the error message.

    Raises:
        TypeError: If the two Tensors have different dtypes, or if `arg` and `dyn_arg`
            are not both Tensors nor both tuple/list.
        ValueError: If the Tensor ranks differ, if a declared dimension that is not
            `SHAPE_DIM_ANY` differs from the real one, or if two sequences differ in length.
    """
    if isinstance(arg, Tensor) and isinstance(dyn_arg, Tensor):
        if arg.dtype != dyn_arg.dtype:
            raise TypeError(f"When using decorator enable_dynamic, input tensor dtype = {arg.dtype}, "
                            f"dynamic tensor dtype = {dyn_arg.dtype}, tensor dtypes are not the same.")
        # A declaration whose shape is reported unknown accepts any real shape.
        if is_dim_unknown(dyn_arg.shape):
            return
        real_shape = arg.shape
        declared_shape = dyn_arg.shape
        rank_differs = len(real_shape) != len(declared_shape)
        if rank_differs or any(declared is not SHAPE_DIM_ANY and actual != declared
                               for actual, declared in zip(real_shape, declared_shape)):
            raise ValueError(f"When using decorator enable_dynamic, input tensor shape = {real_shape}, "
                             f"dynamic tensor shape = {declared_shape}, tensor shapes are not the same.")
        return

    both_sequences = isinstance(arg, (tuple, list)) and isinstance(dyn_arg, (tuple, list))
    if not both_sequences:
        raise TypeError(f"When using decorator enable_dynamic, the type between argument '{name}' "
                        f"and corresponding input are not the same.")
    if len(arg) != len(dyn_arg):
        raise ValueError("Input sequences must have the same structure and length.")
    for real_item, declared_item in zip(arg, dyn_arg):
        _check_arg_type_shape(real_item, declared_item, name)
97
+
98
+
99
def generate_dynamic_sequence_args(args_list, dyn_args_list):
    """
    Build the compile-time counterpart of a real argument.

    Args:
        args_list: The real argument, a Tensor or a nested tuple/list of Tensors.
        dyn_args_list: The dynamic declaration with the same structure.

    Returns:
        For a Tensor, the declaration when its shape differs from the real shape,
        otherwise the real Tensor. For a sequence, a new object of the same type as
        `args_list` built from the element-wise results.
    """
    if isinstance(args_list, Tensor):
        if args_list.shape != dyn_args_list.shape:
            return dyn_args_list
        return args_list
    converted = [generate_dynamic_sequence_args(real_item, declared_item)
                 for real_item, declared_item in zip(args_list, dyn_args_list)]
    # Rebuild with the caller's own container type.
    return type(args_list)(converted)
107
+
108
+
109
def generate_dynamic_tensor_args(args_list, dynamic_shapes):
    """
    Generate compile args by substituting the declared dynamic arguments.

    Args:
        args_list: The real positional arguments.
        dynamic_shapes (dict): Maps a positional index to a `(name, dynamic_arg)` pair.

    Returns:
        list, a copy of `args_list` where each indexed entry has been replaced by the
        result of `generate_dynamic_sequence_args`.

    Raises:
        ValueError: If any tuple/list argument lacks the `__ms_mutable__` attribute, or
            propagated from the input and shape checks.
        TypeError: Propagated from the input and type checks.
    """
    compile_args = list(args_list)
    for position, real_arg in enumerate(args_list):
        is_sequence = isinstance(real_arg, (tuple, list))
        # NOTE(review): this check runs for every sequence argument, including ones with no
        # dynamic declaration - confirm that is intended.
        if is_sequence and not hasattr(real_arg, "__ms_mutable__"):
            raise ValueError(f"When using decorator enable_dynamic, the corresponding attribute of input should be "
                             f"mutable(tuple/list)")
        if position in dynamic_shapes:
            _check_input_valid(real_arg)
            arg_name, declared_arg = dynamic_shapes[position]
            _check_arg_type_shape(real_arg, declared_arg, arg_name)
            compile_args[position] = generate_dynamic_sequence_args(real_arg, declared_arg)
    logger.debug(f"args_list: {args_list}, dynamic_shapes: {dynamic_shapes}, "
                 f"new_compile_args: {compile_args}")
    return compile_args
125
+
126
+
127
def enable_dynamic(**kwargs):
    """
    Use to specify whether the shape of the parameter is dynamic shape or dynamic rank.

    Note:
        - It needs to be used in conjunction with the JIT interface. Without using the JIT decorator,
          the dynamic shape and dynamic rank functions will not be enabled.
        - In the scenario where both set_context(mode=GRAPH_MODE) and nn.Cell are set simultaneously,
          using enable_dynamic will report an error.

    Args:
        kwargs (dict): The input types are Tensor, tuple[Tensor] and list[Tensor]. If one or
            more dimensions in the shape of the parameter need to be specified as dynamic shapes,
            the corresponding dimensions in the shape can be set to None. If the shape that needs
            to generate specified parameters is dynamic rank, the shape can be set to None.

    Returns:
        Function, return a function that specifies the dynamic shape information of the parameter.

    Raises:
        ValueError: If no keyword argument is given, if the decorated object is not a function
            or method, if more keywords are given than the function has parameters, or if a
            keyword does not match any parameter of the decorated function.
        TypeError: If an argument is not a Tensor or a tuple/list of Tensor, or if its shape
            does not contain at least one None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import Tensor
        >>> from mindspore import enable_dynamic
        >>> from mindspore import jit
        ...
        >>> x = Tensor(np.random.randn(2, 3), ms.float32)
        >>> y = Tensor(np.random.randn(2, 3), ms.float32)
        ...
        >>> # Specify parameter y as dynamic shape
        >>> @enable_dynamic(y=Tensor(shape=None, dtype=ms.float32))
        ... @jit
        ... def func(x, y):
        ...     return x + 1, y + 1
        ...
        >>> out = func(x, y)
    """
    # Check inputs at first.
    if not kwargs:
        raise ValueError("When using decorator enable_dynamic, the input cannot be empty!")
    for name, arg in kwargs.items():
        _check_arg_type_valid(arg, name)
        if not _check_arg_shape_valid(arg, name):
            raise TypeError(f"When using decorator enable_dynamic, the shape of argument '{name}' "
                            f"at least have one None.")

    def decorator(func):
        if not isinstance(func, (types.FunctionType, types.MethodType)):
            raise ValueError(f"Decorator enable_dynamic can only be used for function or method "
                             f"decorated by ms.jit, but got {func}.")
        signature = inspect.signature(func)
        sigs_name = [sig_name for sig_name in signature.parameters if sig_name != "self"]
        if len(kwargs) > len(sigs_name):
            raise ValueError(f"When using decorator enable_dynamic, the number of arguments {len(kwargs)} "
                             f"exceeds the number of function arguments {len(sigs_name)}.")
        # Generate dynamic args, keyed by the positional index of each parameter.
        # Keyword names and parameter names are both unique, so an index cannot repeat.
        dynamic_args = {}
        for key, value in kwargs.items():
            if key not in sigs_name:
                # Fail with a clear message instead of the bare "'x' is not in list" from list.index.
                raise ValueError(f"When using decorator enable_dynamic, the argument '{key}' does not match "
                                 f"any parameter of the decorated function, which accepts {sigs_name}.")
            dynamic_args[sigs_name.index(key)] = (key, value)
        # Set dynamic_tensor_shape to func.
        inner_func = inspect.unwrap(func, stop=lambda f: not hasattr(f, '__wrapped__'))
        setattr(get_func(inner_func), ENABLE_DYNAMIC, dynamic_args)
        logger.info(f"Set enable dynamic: {dynamic_args} to {inner_func}")
        return func
    return decorator
@@ -14,10 +14,14 @@
14
14
  # ============================================================================
15
15
  """File system registration management"""
16
16
  from mindspore import log as logger
17
+ from mindspore import _checkparam as Validator
18
+
19
+ mindio_server_info = {"memfs.data_block_pool_capacity_in_gb": "100"}
17
20
 
18
21
 
19
22
  class FileSystem:
20
23
  """File operation interface manager"""
24
+
21
25
  def __init__(self):
22
26
  self.create = open
23
27
  self.create_args = ("ab",)
@@ -35,20 +39,33 @@ def _register_basic_file_system(fs: FileSystem):
35
39
  return True
36
40
 
37
41
 
38
- def _register_mindio_file_system(fs: FileSystem):
39
- """register mindio file system"""
42
+ def _init_mindio():
43
+ """Initialize MindIO and return the module if successful"""
40
44
  try:
41
- import mindio
45
+ import mindio_acp as mindio
46
+ ret = mindio.initialize(server_info=mindio_server_info)
47
+ if ret == 0:
48
+ return mindio
49
+ logger.warning(f"Failed to initialize mindio_acp: ret = {ret}")
42
50
  except ImportError:
43
- return False
51
+ pass
44
52
  try:
53
+ import mindio
45
54
  ret = mindio.initialize()
46
- except AttributeError as e:
47
- logger.warning(f"Failed to initialize MindIO: {e}")
48
- return False
49
- if ret != 0:
50
- logger.warning(f"Failed to initialize MindIO: ret = {ret}")
55
+ if ret == 0:
56
+ return mindio
57
+ logger.warning(f"Failed to initialize mindio: ret = {ret}")
58
+ except ImportError:
59
+ pass
60
+ return None
61
+
62
+
63
+ def _register_mindio_file_system(fs: FileSystem):
64
+ """register mindio file system"""
65
+ mindio = _init_mindio()
66
+ if mindio is None:
51
67
  return False
68
+
52
69
  fs.create = mindio.create_file
53
70
  fs.create_args = ()
54
71
  fs.open = mindio.open_file
@@ -56,3 +73,36 @@ def _register_mindio_file_system(fs: FileSystem):
56
73
  fs.backend = "mindio"
57
74
  logger.info("The weights are stored using MindIO as the backend.")
58
75
  return True
76
+
77
+
78
def set_mindio_server_info(data_block_pool_capacity_in_gb=100):
    """
    Configure MindIO server settings.

    The value is validated with `Validator.check_positive_int` and then stored, as a
    string, under the "memfs.data_block_pool_capacity_in_gb" key of the module-level
    `mindio_server_info` dict.

    Args:
        data_block_pool_capacity_in_gb (int): Memory pool capacity for data blocks in gigabytes.
    """
    Validator.check_positive_int(data_block_pool_capacity_in_gb, "data_block_pool_capacity_in_gb")
    # The dict is mutated in place, so no `global` rebinding is needed.
    capacity_text = str(data_block_pool_capacity_in_gb)
    mindio_server_info["memfs.data_block_pool_capacity_in_gb"] = capacity_text
88
+
89
+
90
def mindio_preload(ckpt_file_name):
    """
    Preload data into memory using MindIO for faster access.

    Args:
        ckpt_file_name (str): Checkpoint file name.

    Returns:
        bool: True once the backend's `preload` has been called; False when no MindIO
        backend could be initialized or the backend has no `preload` attribute.
    """
    Validator.check_value_type('ckpt_file_name', ckpt_file_name, str, "mindio_preload")
    backend = _init_mindio()
    if backend is None:
        return False
    if hasattr(backend, 'preload'):
        # The return value of preload itself is not inspected.
        backend.preload(ckpt_file_name)
        return True
    logger.warning("MindIO module does not have preload method")
    return False
@@ -15,7 +15,23 @@
15
15
  """The removable handle for cell hook function."""
16
16
  from __future__ import absolute_import
17
17
  import weakref
18
+ from collections import OrderedDict
18
19
  from mindspore._c_expression import TensorPy as Tensor_
20
+ from mindspore._check_jit_forbidden_api import jit_forbidden_register
21
+
22
+
23
# Global counter used to mark the `Parameter` hook and `Cell` hook version.
_HOOK_VERSION = 0


def _update_hook_version():
    """Increment the module-level hook version counter by one."""
    # `global` is required here because the name is rebound.
    global _HOOK_VERSION
    _HOOK_VERSION += 1


def _hook_version():
    """Return the current value of the module-level hook version counter."""
    # Read-only access resolves to the module global without a `global` statement.
    return _HOOK_VERSION
19
35
 
20
36
 
21
37
  class _TensorHookHandle:
@@ -31,8 +47,9 @@ class _TensorHookHandle:
31
47
 
32
48
  def __init__(self, tensor):
33
49
  self.id = None
34
- self.tensor_ref = weakref.ref(tensor)
50
+ self.tensor_weakref = weakref.ref(tensor)
35
51
 
52
+ @jit_forbidden_register
36
53
  def remove(self):
37
54
  """
38
55
  Remove the tensor hook function, which corresponds to this '_TensorHookHandle' object.
@@ -67,9 +84,9 @@ class _TensorHookHandle:
67
84
  """
68
85
  if self.id is not None:
69
86
  Tensor_.remove_hook(self.id)
70
- tensor = self.tensor_ref()
87
+ tensor = self.tensor_weakref()
71
88
  if tensor is not None:
72
- tensor._remove_hook() # pylint:disable=protected-access
89
+ tensor._remove_hook() # pylint:disable=protected-access
73
90
 
74
91
 
75
92
  class HookHandle:
@@ -99,6 +116,7 @@ class HookHandle:
99
116
  if extra_dict is not None:
100
117
  self.extra_dict_ref = weakref.ref(extra_dict)
101
118
 
119
+ @jit_forbidden_register
102
120
  def remove(self):
103
121
  """
104
122
  Remove the cell hook function, which corresponds to this 'HookHandle' object.
@@ -145,6 +163,8 @@ class HookHandle:
145
163
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
146
164
  value= [ 2.00000000e+00]))
147
165
  """
166
+ _update_hook_version() # pylint:disable=protected-access
167
+
148
168
  if self.hook_dict_ref is not None:
149
169
  hook_dict = self.hook_dict_ref()
150
170
  if hook_dict is not None and self.handle_id in hook_dict:
@@ -154,3 +174,62 @@ class HookHandle:
154
174
  extra_dict = self.extra_dict_ref()
155
175
  if extra_dict is not None and self.handle_id in extra_dict:
156
176
  del extra_dict[self.handle_id]
177
+
178
+
179
def _hook_display_name(hook_fn):
    """Return a printable name for a hook, even when it has no `__name__`."""
    # functools.partial objects and callable instances do not define `__name__`;
    # fall back to repr() so error reporting never fails with AttributeError.
    name = getattr(hook_fn, "__name__", None)
    return name if name is not None else repr(hook_fn)


def _check_hook_results(pre_res, new_res, hook_fn):
    """
    Validate the value returned by a hook against the value it replaces.

    Args:
        pre_res: Value that was passed to the hook; only its length is used.
        new_res: Value returned by the hook.
        hook_fn: The hook callable; used only to name the hook in error messages.

    Raises:
        RuntimeError: If `new_res` is not a tuple, or if its length differs from
            the length of `pre_res`.
    """
    if not isinstance(new_res, tuple):
        raise RuntimeError(f"hook {_hook_display_name(hook_fn)} should return a tuple of grad.")

    new_res_len = len(new_res)
    pre_res_len = len(pre_res)
    if new_res_len != pre_res_len:
        raise RuntimeError(
            f"hook {_hook_display_name(hook_fn)} returned incorrect length {new_res_len}, "
            f"expected {pre_res_len}."
        )
189
+
190
+
191
class _HookUtils:
    r"""
    Internal helpers for registering hooks in a dict and running them.
    """

    @staticmethod
    def register_hook(hook_dict, hook_fn):
        """
        Add a hook function to a hook dict and return its removable handle.

        Args:
            hook_dict (dict): Existing hook dict, or None to create a new OrderedDict.
            hook_fn (function): Hook function to store.

        Returns:
            tuple: The hook dict that now holds `hook_fn` (keyed by the handle id)
            and the `HookHandle` created for it.
        """
        registry = OrderedDict() if hook_dict is None else hook_dict
        handle = HookHandle(registry)
        registry[handle.handle_id] = hook_fn
        return registry, handle

    @staticmethod
    def run_hook(hook_dict, args):
        """
        Call every hook in `hook_dict` with `args`, chaining replaced values.

        A hook that returns a non-None value replaces the first argument for all
        later hooks, after the value has been checked by `_check_hook_results`.

        Args:
            hook_dict (dict): Dictionary of registered hooks.
            args (tuple): Arguments to pass to the hook functions.

        Returns:
            The replaced first argument if at least one hook returned a value;
            otherwise, None.
        """
        call_args = list(args)
        replaced = False
        # Iterate over a snapshot so the dict may change while hooks are running.
        for hook in tuple(hook_dict.values()):
            outcome = hook(*call_args)
            if outcome is None:
                continue
            _check_hook_results(call_args[0], outcome, hook)
            call_args[0] = outcome
            replaced = True
        if not replaced:
            return None
        return call_args[0]
@@ -27,7 +27,11 @@ class JitConfig:
27
27
  adopt KernelByKernel execution mode.
28
28
  - ``"O1"``: Using commonly used optimizations and automatic operator fusion optimizations,
29
29
  adopt KernelByKernel execution mode.
30
- - ``"O2"``: Ultimate performance optimization, adopt Sink execution mode.
30
+ - ``"O2"``: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
31
+ for Ascend model compilation and execution. Note: O2 only supports GRAPH Mode in Ascend,
32
+ only supports whole graph sinking or sub graph sinking in pipeline parallel, and does not support
33
+ dynamic shape scenes. In addition, this mode incurs additional compilation costs and is difficult to
34
+ debug and tune.
31
35
 
32
36
  exc_mode (str, optional): Control the execution mode of the model.
33
37
  Supports ["auto", "sink", "no_sink"]. Default: ``"auto"`` .
@@ -28,6 +28,7 @@ from mindspore._c_expression import TraceRecorder as tr
28
28
  from mindspore._c_expression import JitExecutor_
29
29
  from mindspore._c_expression import TensorPy as Tensor, CSRTensor, COOTensor
30
30
  from mindspore._c_expression import typing
31
+ from mindspore.common.jit_config import JitConfig
31
32
 
32
33
 
33
34
  class TraceJitContext(JitContext):
@@ -123,19 +124,19 @@ def nested_run(obj, cell, *args):
123
124
  return file_names, linenos, res
124
125
 
125
126
 
126
- def _jit_trace():
127
+ def _jit_trace(jit_config):
127
128
  """Return the wrapped function for trace mode jit."""
128
129
  def wrap_func(fn):
129
130
  if hasattr(fn, "construct"):
130
131
  if isinstance(fn, ms.nn.Cell):
131
132
  # Bound the cell object to get the self arg.
132
- return types.MethodType(_jit_trace()(fn.construct.__func__), fn)
133
+ return types.MethodType(_jit_trace(jit_config)(fn.construct.__func__), fn)
133
134
  if isinstance(fn, type) and issubclass(fn, ms.nn.Cell):
134
- fn.construct = _jit_trace()(fn.construct)
135
+ fn.construct = _jit_trace(jit_config)(fn.construct)
135
136
  return fn
136
137
 
137
138
  if isinstance(fn, types.MethodType):
138
- return types.MethodType(_jit_trace()(fn.__func__), fn.__self__)
139
+ return types.MethodType(_jit_trace(jit_config)(fn.__func__), fn.__self__)
139
140
 
140
141
  if not isinstance(fn, types.FunctionType):
141
142
  logger.warning(f"The fn should be function, method or cell instance/class, but got {fn}")
@@ -150,6 +151,10 @@ def _jit_trace():
150
151
  if jit_context():
151
152
  return fn(*args, **kwargs)
152
153
  # Start trace process.
154
+ if jit_config:
155
+ jit_config_dict = jit_config.jit_config_dict
156
+ else:
157
+ jit_config_dict = JitConfig().jit_config_dict
153
158
  if kwargs:
154
159
  bound_arguments = inspect.signature(fn).bind(*args, **kwargs)
155
160
  bound_arguments.apply_defaults()
@@ -170,14 +175,16 @@ def _jit_trace():
170
175
  line_str = fn.__code__.co_filename + ":" + str(fn.__code__.co_firstlineno)
171
176
  generate_name = generate_name + '#[' + line_str + ']'
172
177
 
173
- new_compile = _jit_trace_begin(generate_name, *jit_args)
178
+ new_compile = _jit_trace_begin(
179
+ generate_name, *jit_args, jit_config=jit_config_dict)
174
180
  if new_compile:
175
181
  fn_res = fn(*args, **kwargs)
176
182
  logger.debug(f'fn: {fn}, fn_res: {fn_res}, line: {line_str}')
177
183
  # Use fn's output to build func graph's output.
178
- output = _jit_trace_end(fn_res)
184
+ output = _jit_trace_end(fn_res, jit_config=jit_config_dict)
179
185
  else:
180
- output = _jit_trace_end(None) # Run with compilation.
186
+ # Run with compilation.
187
+ output = _jit_trace_end(None, jit_config=jit_config_dict)
181
188
  logger.debug(f'output: {output}')
182
189
  return output
183
190
 
@@ -224,7 +231,7 @@ def _get_args_for_run(args):
224
231
  return tuple(new_args)
225
232
 
226
233
 
227
- def _jit_trace_begin(fn_name, *args):
234
+ def _jit_trace_begin(fn_name, *args, **kwargs):
228
235
  """
229
236
  Start to build a MindIR func graph for a code snippet by trace method.
230
237
 
@@ -257,6 +264,10 @@ def _jit_trace_begin(fn_name, *args):
257
264
  ...
258
265
  >>> out = tensor_add(x, y)
259
266
  """
267
+ if "jit_config" in kwargs:
268
+ jit_config = kwargs["jit_config"]
269
+ else:
270
+ jit_config = JitConfig().jit_config_dict
260
271
  global _using_trace
261
272
  if _using_trace:
262
273
  raise RuntimeError(
@@ -279,7 +290,7 @@ def _jit_trace_begin(fn_name, *args):
279
290
  if not _compile_only and phase in _trace_compile_cache:
280
291
  logger.debug('Had compiled, just run.')
281
292
  _trace_jit_context.compiled = True
282
- output = tr.get_instance().run_graph(phase, args)
293
+ output = tr.get_instance().run_graph(phase, jit_config, args)
283
294
  from mindspore.common.api import _convert_python_data
284
295
  _trace_jit_context.result = _convert_python_data(output)
285
296
  logger.debug(f'jit trace result: {_trace_jit_context.result}')
@@ -295,7 +306,7 @@ def _jit_trace_begin(fn_name, *args):
295
306
  return True
296
307
 
297
308
 
298
- def _jit_trace_end(*output_args):
309
+ def _jit_trace_end(*output_args, **kwargs):
299
310
  """
300
311
  Finish building a MindIR func graph for a code snippet by trace method.
301
312
 
@@ -330,19 +341,23 @@ def _jit_trace_end(*output_args):
330
341
  ...
331
342
  >>> out = tensor_add(x, y)
332
343
  """
344
+ if "jit_config" in kwargs:
345
+ jit_config = kwargs["jit_config"]
346
+ else:
347
+ jit_config = JitConfig().jit_config_dict
333
348
  if _trace_jit_context.compiled:
334
349
  output = _trace_jit_context.result
335
350
  logger.debug(f'jit trace result: {output}')
336
351
  else:
337
352
  logger.debug(f'output_args: {output_args}')
338
353
  file_names, linenos = _get_caller_lines()
339
- tr.get_instance().end_graph(file_names, linenos, *output_args)
354
+ tr.get_instance().end_graph(file_names, linenos, jit_config, *output_args)
340
355
  if _compile_only:
341
356
  output = output_args[0] if len(output_args) == 1 else output_args
342
357
  else:
343
358
  args = _get_args_for_run(_trace_jit_context.args)
344
359
  output = tr.get_instance().run_graph(
345
- _trace_jit_context.phase, args)
360
+ _trace_jit_context.phase, jit_config, args)
346
361
  from mindspore.common.api import _convert_python_data
347
362
  output = _convert_python_data(output)
348
363
  logger.debug(f'jit trace result: {output}')
@@ -32,9 +32,11 @@ def lazy_inline(fn=None, attrs=None, policy=None):
32
32
  static_graph_expert_programming.html#using-lazy-inline-decorator>`_ .
33
33
 
34
34
  .. warning::
35
- This feature is only supported on Ascend and is not supported on other hardwares.
36
- The construct parameters must be positional or key word arguments and have not default values.
37
- The cell has not switch sub graph.
35
+ - This feature is only supported on Ascend and is not supported on other hardwares.
36
+ - The construct parameters must be positional or keyword arguments and must not have default values.
37
+ - The cell must not contain a switch subgraph.
38
+ - In the gradient accumulation scenario, it is recommended to use the @lazy_inline decorator to
39
+ reduce compilation time, and this decorator is only allowed to configure on the outermost cell.
38
40
 
39
41
  Args:
40
42
  fn (function): `__init__` function of a cell.
@@ -16,10 +16,10 @@
16
16
  # ============================================================================
17
17
  """Numpy data type for MindSpore."""
18
18
 
19
- from mindspore._c_expression.np_dtypes import np_version_valid
20
- if np_version_valid(True):
19
+ from mindspore._c_expression.np_dtypes import np_dtype_valid
20
+ if np_dtype_valid(True):
21
21
  from mindspore._c_expression.np_dtypes import bfloat16 # pylint: disable=unused-import
22
22
 
23
23
  __all__ = []
24
- if np_version_valid(False):
24
+ if np_dtype_valid(False):
25
25
  __all__.extend(["bfloat16"])