mindspore-2.5.0-cp310-cp310-win_amd64.whl → mindspore-2.6.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (493)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +25 -194
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +109 -75
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +2014 -3386
  46. mindspore/common/api.py +386 -355
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/generator.py +3 -0
  52. mindspore/common/hook_handle.py +5 -3
  53. mindspore/common/initializer.py +10 -6
  54. mindspore/common/jit_begin_end.py +94 -0
  55. mindspore/common/jit_config.py +6 -1
  56. mindspore/common/jit_context.py +76 -0
  57. mindspore/common/jit_trace.py +378 -0
  58. mindspore/common/lazy_inline.py +2 -2
  59. mindspore/common/mutable.py +5 -4
  60. mindspore/common/parameter.py +106 -39
  61. mindspore/common/seed.py +2 -2
  62. mindspore/common/sparse_tensor.py +23 -17
  63. mindspore/common/tensor.py +332 -714
  64. mindspore/communication/__init__.py +7 -5
  65. mindspore/communication/_comm_helper.py +47 -2
  66. mindspore/communication/comm_func.py +70 -53
  67. mindspore/communication/management.py +83 -17
  68. mindspore/context.py +228 -571
  69. mindspore/dataset/__init__.py +44 -20
  70. mindspore/dataset/audio/__init__.py +2 -8
  71. mindspore/dataset/audio/transforms.py +3 -17
  72. mindspore/dataset/core/config.py +3 -3
  73. mindspore/dataset/engine/cache_client.py +1 -1
  74. mindspore/dataset/engine/datasets.py +102 -120
  75. mindspore/dataset/engine/datasets_audio.py +22 -22
  76. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  77. mindspore/dataset/engine/datasets_text.py +78 -85
  78. mindspore/dataset/engine/datasets_user_defined.py +109 -77
  79. mindspore/dataset/engine/datasets_vision.py +111 -108
  80. mindspore/dataset/engine/iterators.py +5 -3
  81. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  82. mindspore/dataset/engine/samplers.py +279 -57
  83. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  84. mindspore/dataset/engine/validators.py +10 -0
  85. mindspore/dataset/text/__init__.py +7 -6
  86. mindspore/dataset/text/transforms.py +6 -5
  87. mindspore/dataset/text/utils.py +3 -3
  88. mindspore/dataset/transforms/__init__.py +0 -9
  89. mindspore/dataset/transforms/transforms.py +3 -3
  90. mindspore/dataset/utils/browse_dataset.py +1 -1
  91. mindspore/dataset/vision/__init__.py +2 -9
  92. mindspore/dataset/vision/transforms.py +202 -158
  93. mindspore/dataset/vision/utils.py +7 -5
  94. mindspore/device_context/ascend/op_debug.py +60 -1
  95. mindspore/device_context/ascend/op_tuning.py +0 -4
  96. mindspore/device_manager.py +39 -3
  97. mindspore/dnnl.dll +0 -0
  98. mindspore/dpcmi.dll +0 -0
  99. mindspore/experimental/es/embedding_service.py +35 -27
  100. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
  101. mindspore/experimental/map_parameter.py +4 -4
  102. mindspore/experimental/optim/adadelta.py +22 -26
  103. mindspore/experimental/optim/adagrad.py +4 -4
  104. mindspore/experimental/optim/adam.py +4 -0
  105. mindspore/experimental/optim/adamax.py +4 -4
  106. mindspore/experimental/optim/adamw.py +4 -0
  107. mindspore/experimental/optim/asgd.py +1 -1
  108. mindspore/experimental/optim/lr_scheduler.py +40 -22
  109. mindspore/experimental/optim/radam.py +5 -5
  110. mindspore/experimental/optim/rprop.py +1 -1
  111. mindspore/experimental/optim/sgd.py +1 -1
  112. mindspore/hal/contiguous_tensors_handle.py +6 -10
  113. mindspore/hal/device.py +55 -81
  114. mindspore/hal/event.py +38 -55
  115. mindspore/hal/memory.py +115 -147
  116. mindspore/hal/stream.py +81 -125
  117. mindspore/include/dataset/constants.h +7 -4
  118. mindspore/include/dataset/execute.h +2 -2
  119. mindspore/jpeg62.dll +0 -0
  120. mindspore/log.py +40 -2
  121. mindspore/mindrecord/__init__.py +20 -7
  122. mindspore/mindspore_backend_common.dll +0 -0
  123. mindspore/mindspore_backend_manager.dll +0 -0
  124. mindspore/mindspore_common.dll +0 -0
  125. mindspore/mindspore_core.dll +0 -0
  126. mindspore/mindspore_dump.dll +0 -0
  127. mindspore/mindspore_frontend.dll +0 -0
  128. mindspore/mindspore_glog.dll +0 -0
  129. mindspore/mindspore_memory_pool.dll +0 -0
  130. mindspore/mindspore_ms_backend.dll +0 -0
  131. mindspore/mindspore_ops.dll +0 -0
  132. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  133. mindspore/mindspore_ops_kernel_common.dll +0 -0
  134. mindspore/mindspore_profiler.dll +0 -0
  135. mindspore/mindspore_pyboost.dll +0 -0
  136. mindspore/mindspore_pynative.dll +0 -0
  137. mindspore/mindspore_res_manager.dll +0 -0
  138. mindspore/mindspore_runtime_pipeline.dll +0 -0
  139. mindspore/mint/__init__.py +133 -702
  140. mindspore/mint/distributed/__init__.py +5 -1
  141. mindspore/mint/distributed/distributed.py +198 -113
  142. mindspore/mint/linalg/__init__.py +2 -0
  143. mindspore/mint/nn/__init__.py +280 -18
  144. mindspore/mint/nn/functional.py +282 -64
  145. mindspore/mint/nn/layer/__init__.py +4 -0
  146. mindspore/mint/nn/layer/_functions.py +7 -3
  147. mindspore/mint/nn/layer/activation.py +120 -13
  148. mindspore/mint/nn/layer/conv.py +234 -28
  149. mindspore/mint/nn/layer/normalization.py +15 -16
  150. mindspore/mint/nn/layer/padding.py +1 -1
  151. mindspore/mint/nn/layer/pooling.py +66 -1
  152. mindspore/mint/optim/__init__.py +2 -1
  153. mindspore/mint/optim/sgd.py +171 -0
  154. mindspore/msobj140.dll +0 -0
  155. mindspore/mspdb140.dll +0 -0
  156. mindspore/mspdbcore.dll +0 -0
  157. mindspore/mspdbst.dll +0 -0
  158. mindspore/mspft140.dll +0 -0
  159. mindspore/msvcdis140.dll +0 -0
  160. mindspore/msvcp140_1.dll +0 -0
  161. mindspore/msvcp140_2.dll +0 -0
  162. mindspore/msvcp140_atomic_wait.dll +0 -0
  163. mindspore/msvcp140_codecvt_ids.dll +0 -0
  164. mindspore/nn/__init__.py +4 -1
  165. mindspore/nn/cell.py +1253 -179
  166. mindspore/nn/layer/activation.py +23 -21
  167. mindspore/nn/layer/basic.py +22 -16
  168. mindspore/nn/layer/container.py +1 -1
  169. mindspore/nn/layer/conv.py +53 -42
  170. mindspore/nn/layer/embedding.py +9 -8
  171. mindspore/nn/layer/normalization.py +48 -42
  172. mindspore/nn/layer/pooling.py +75 -31
  173. mindspore/nn/layer/transformer.py +11 -10
  174. mindspore/nn/learning_rate_schedule.py +4 -2
  175. mindspore/nn/loss/loss.py +27 -19
  176. mindspore/nn/optim/ada_grad.py +6 -5
  177. mindspore/nn/optim/adadelta.py +9 -7
  178. mindspore/nn/optim/adafactor.py +1 -1
  179. mindspore/nn/optim/adam.py +18 -14
  180. mindspore/nn/optim/adamax.py +8 -7
  181. mindspore/nn/optim/adasum.py +5 -5
  182. mindspore/nn/optim/asgd.py +3 -1
  183. mindspore/nn/optim/ftrl.py +11 -9
  184. mindspore/nn/optim/lamb.py +1 -1
  185. mindspore/nn/optim/lazyadam.py +12 -10
  186. mindspore/nn/optim/momentum.py +7 -6
  187. mindspore/nn/optim/optimizer.py +2 -2
  188. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  189. mindspore/nn/optim/rmsprop.py +13 -12
  190. mindspore/nn/optim/rprop.py +9 -7
  191. mindspore/nn/optim/sgd.py +9 -6
  192. mindspore/nn/optim/tft_wrapper.py +5 -2
  193. mindspore/nn/probability/bijector/bijector.py +17 -11
  194. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  195. mindspore/nn/probability/bijector/invert.py +2 -2
  196. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  197. mindspore/nn/probability/bijector/softplus.py +3 -2
  198. mindspore/nn/probability/distribution/beta.py +3 -3
  199. mindspore/nn/probability/distribution/categorical.py +1 -1
  200. mindspore/nn/probability/distribution/cauchy.py +4 -2
  201. mindspore/nn/probability/distribution/exponential.py +6 -7
  202. mindspore/nn/probability/distribution/gamma.py +2 -2
  203. mindspore/nn/probability/distribution/gumbel.py +2 -2
  204. mindspore/nn/probability/distribution/half_normal.py +5 -3
  205. mindspore/nn/probability/distribution/logistic.py +5 -3
  206. mindspore/nn/probability/distribution/poisson.py +1 -1
  207. mindspore/nn/probability/distribution/uniform.py +5 -3
  208. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  209. mindspore/nn/reinforcement/tensor_array.py +1 -1
  210. mindspore/nn/wrap/__init__.py +6 -6
  211. mindspore/nn/wrap/cell_wrapper.py +178 -117
  212. mindspore/nn/wrap/grad_reducer.py +45 -36
  213. mindspore/nn/wrap/loss_scale.py +3 -3
  214. mindspore/numpy/array_creations.py +3 -3
  215. mindspore/numpy/array_ops.py +1 -1
  216. mindspore/numpy/utils.py +1 -2
  217. mindspore/numpy/utils_const.py +1 -2
  218. mindspore/opencv_core452.dll +0 -0
  219. mindspore/opencv_imgcodecs452.dll +0 -0
  220. mindspore/opencv_imgproc452.dll +0 -0
  221. mindspore/ops/__init__.py +3 -2
  222. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  223. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  224. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  225. mindspore/ops/_register_for_op.py +0 -11
  226. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  227. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  228. mindspore/ops/_vmap/vmap_array_ops.py +32 -6
  229. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  230. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  231. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  232. mindspore/ops/auto_generate/__init__.py +4 -3
  233. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
  234. mindspore/ops/auto_generate/gen_extend_func.py +286 -208
  235. mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
  236. mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
  237. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  238. mindspore/ops/composite/__init__.py +2 -1
  239. mindspore/ops/composite/base.py +19 -24
  240. mindspore/ops/composite/math_ops.py +6 -16
  241. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  242. mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
  243. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  244. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  248. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  249. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  250. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  251. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  252. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  254. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  255. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  256. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  257. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  259. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  260. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  263. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  264. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  267. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  268. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  271. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  272. mindspore/ops/function/__init__.py +28 -2
  273. mindspore/ops/function/_add_attr_func.py +58 -0
  274. mindspore/ops/function/array_func.py +1631 -2347
  275. mindspore/ops/function/clip_func.py +38 -45
  276. mindspore/ops/function/debug_func.py +36 -44
  277. mindspore/ops/function/grad/__init__.py +1 -0
  278. mindspore/ops/function/grad/grad_func.py +104 -71
  279. mindspore/ops/function/image_func.py +1 -1
  280. mindspore/ops/function/linalg_func.py +46 -78
  281. mindspore/ops/function/math_func.py +3024 -3855
  282. mindspore/ops/function/nn_func.py +678 -274
  283. mindspore/ops/function/other_func.py +159 -1
  284. mindspore/ops/function/parameter_func.py +17 -30
  285. mindspore/ops/function/random_func.py +216 -361
  286. mindspore/ops/function/reshard_func.py +4 -70
  287. mindspore/ops/function/sparse_func.py +3 -3
  288. mindspore/ops/function/sparse_unary_func.py +5 -5
  289. mindspore/ops/function/spectral_func.py +25 -58
  290. mindspore/ops/function/vmap_func.py +26 -18
  291. mindspore/ops/functional.py +8 -5
  292. mindspore/ops/functional_overload.py +655 -4
  293. mindspore/ops/op_info_register.py +32 -244
  294. mindspore/ops/operations/__init__.py +21 -14
  295. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  296. mindspore/ops/operations/_grad_ops.py +1 -10
  297. mindspore/ops/operations/_inner_ops.py +5 -76
  298. mindspore/ops/operations/_ms_kernel.py +4 -10
  299. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  300. mindspore/ops/operations/_scalar_ops.py +3 -2
  301. mindspore/ops/operations/_sequence_ops.py +1 -1
  302. mindspore/ops/operations/_tensor_array.py +1 -1
  303. mindspore/ops/operations/array_ops.py +39 -24
  304. mindspore/ops/operations/comm_ops.py +150 -107
  305. mindspore/ops/operations/custom_ops.py +287 -32
  306. mindspore/ops/operations/debug_ops.py +119 -16
  307. mindspore/ops/operations/inner_ops.py +1 -1
  308. mindspore/ops/operations/linalg_ops.py +1 -58
  309. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  310. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  311. mindspore/ops/operations/math_ops.py +21 -18
  312. mindspore/ops/operations/nn_ops.py +67 -224
  313. mindspore/ops/operations/other_ops.py +62 -9
  314. mindspore/ops/operations/random_ops.py +13 -7
  315. mindspore/ops/operations/reshard_ops.py +1 -1
  316. mindspore/ops/operations/sparse_ops.py +2 -2
  317. mindspore/ops/primitive.py +43 -32
  318. mindspore/ops/tensor_method.py +243 -17
  319. mindspore/ops_generate/__init__.py +0 -5
  320. mindspore/ops_generate/aclnn/__init__.py +0 -0
  321. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  322. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  323. mindspore/ops_generate/api/__init__.py +0 -0
  324. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  325. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  326. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  327. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  328. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  329. mindspore/ops_generate/api/gen_api.py +103 -0
  330. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  331. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  332. mindspore/ops_generate/common/__init__.py +0 -0
  333. mindspore/ops_generate/common/gen_constants.py +91 -0
  334. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  335. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  336. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  337. mindspore/ops_generate/gen_ops.py +23 -325
  338. mindspore/ops_generate/op_def/__init__.py +0 -0
  339. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  340. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  341. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
  342. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  343. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  344. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  345. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  346. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  347. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  348. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  349. mindspore/ops_generate/pyboost/__init__.py +0 -0
  350. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  351. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  352. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  353. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  354. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  355. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  356. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  357. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  358. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  359. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  360. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  361. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  362. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  363. mindspore/ops_generate/resources/__init__.py +0 -0
  364. mindspore/ops_generate/resources/resource_list.py +30 -0
  365. mindspore/ops_generate/resources/resource_loader.py +36 -0
  366. mindspore/ops_generate/resources/resource_manager.py +64 -0
  367. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  368. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  369. mindspore/parallel/__init__.py +6 -2
  370. mindspore/parallel/_auto_parallel_context.py +140 -12
  371. mindspore/parallel/_cell_wrapper.py +132 -15
  372. mindspore/parallel/_parallel_serialization.py +95 -4
  373. mindspore/parallel/_ps_context.py +1 -1
  374. mindspore/parallel/_recovery_context.py +7 -2
  375. mindspore/parallel/_tensor.py +142 -18
  376. mindspore/parallel/_utils.py +198 -25
  377. mindspore/parallel/algo_parameter_config.py +3 -3
  378. mindspore/parallel/auto_parallel.py +732 -0
  379. mindspore/parallel/checkpoint_convert.py +159 -0
  380. mindspore/parallel/checkpoint_transform.py +658 -37
  381. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  382. mindspore/parallel/cluster/run.py +1 -1
  383. mindspore/parallel/function/__init__.py +24 -0
  384. mindspore/parallel/function/reshard_func.py +258 -0
  385. mindspore/parallel/nn/__init__.py +25 -0
  386. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  387. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  388. mindspore/parallel/parameter_broadcast.py +24 -13
  389. mindspore/parallel/shard.py +137 -62
  390. mindspore/parallel/transform_safetensors.py +288 -95
  391. mindspore/pgodb140.dll +0 -0
  392. mindspore/pgort140.dll +0 -0
  393. mindspore/profiler/__init__.py +9 -5
  394. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  395. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  397. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
  398. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  399. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  400. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  401. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  402. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  403. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  404. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  405. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  406. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  407. mindspore/profiler/common/constant.py +12 -0
  408. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  409. mindspore/profiler/common/path_manager.py +24 -0
  410. mindspore/profiler/common/profiler_context.py +26 -2
  411. mindspore/profiler/common/profiler_meta_data.py +74 -0
  412. mindspore/profiler/common/profiler_parameters.py +59 -18
  413. mindspore/profiler/common/profiler_path_manager.py +66 -7
  414. mindspore/profiler/dynamic_profiler.py +112 -79
  415. mindspore/profiler/envprofiler.py +26 -1
  416. mindspore/profiler/experimental_config.py +197 -0
  417. mindspore/profiler/mstx.py +57 -14
  418. mindspore/profiler/platform/npu_profiler.py +33 -7
  419. mindspore/profiler/profiler.py +541 -45
  420. mindspore/profiler/profiler_action_controller.py +1 -1
  421. mindspore/profiler/profiler_interface.py +4 -0
  422. mindspore/profiler/schedule.py +57 -22
  423. mindspore/rewrite/api/node.py +15 -13
  424. mindspore/rewrite/api/symbol_tree.py +1 -1
  425. mindspore/run_check/_check_version.py +25 -14
  426. mindspore/run_check/run_check.py +1 -1
  427. mindspore/runtime/__init__.py +2 -2
  428. mindspore/runtime/executor.py +40 -11
  429. mindspore/runtime/memory.py +37 -13
  430. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  431. mindspore/swresample-4.dll +0 -0
  432. mindspore/swscale-6.dll +0 -0
  433. mindspore/tbbmalloc.dll +0 -0
  434. mindspore/tinyxml2.dll +0 -0
  435. mindspore/train/__init__.py +8 -8
  436. mindspore/train/_utils.py +43 -9
  437. mindspore/train/amp.py +1 -1
  438. mindspore/train/callback/__init__.py +2 -2
  439. mindspore/train/callback/_callback.py +2 -16
  440. mindspore/train/callback/_checkpoint.py +24 -40
  441. mindspore/train/callback/_cluster_monitor.py +14 -18
  442. mindspore/train/callback/_flops_collector.py +2 -3
  443. mindspore/train/callback/_history.py +7 -4
  444. mindspore/train/callback/_lambda_callback.py +2 -2
  445. mindspore/train/callback/_landscape.py +0 -3
  446. mindspore/train/callback/_loss_monitor.py +2 -1
  447. mindspore/train/callback/_on_request_exit.py +6 -5
  448. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  449. mindspore/train/callback/_summary_collector.py +8 -13
  450. mindspore/train/callback/_time_monitor.py +2 -1
  451. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
  452. mindspore/train/data_sink.py +25 -2
  453. mindspore/train/dataset_helper.py +4 -5
  454. mindspore/train/loss_scale_manager.py +8 -7
  455. mindspore/train/metrics/accuracy.py +3 -3
  456. mindspore/train/metrics/confusion_matrix.py +9 -9
  457. mindspore/train/metrics/error.py +3 -3
  458. mindspore/train/metrics/hausdorff_distance.py +4 -4
  459. mindspore/train/metrics/mean_surface_distance.py +3 -3
  460. mindspore/train/metrics/metric.py +0 -12
  461. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  462. mindspore/train/metrics/precision.py +8 -6
  463. mindspore/train/metrics/recall.py +9 -9
  464. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  465. mindspore/train/mind_ir_pb2.py +19 -12
  466. mindspore/train/model.py +262 -127
  467. mindspore/train/serialization.py +246 -988
  468. mindspore/train/summary/_summary_adapter.py +2 -2
  469. mindspore/train/summary/summary_record.py +1 -1
  470. mindspore/turbojpeg.dll +0 -0
  471. mindspore/utils/__init__.py +3 -2
  472. mindspore/utils/dryrun.py +4 -2
  473. mindspore/utils/hooks.py +81 -0
  474. mindspore/utils/runtime_execution_order_check.py +2 -0
  475. mindspore/utils/utils.py +138 -4
  476. mindspore/vcmeta.dll +0 -0
  477. mindspore/vcruntime140.dll +0 -0
  478. mindspore/vcruntime140_1.dll +0 -0
  479. mindspore/version.py +1 -1
  480. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
  481. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
  482. mindspore/_install_custom.py +0 -43
  483. mindspore/common/_register_for_adapter.py +0 -74
  484. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  485. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  486. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  487. mindspore/ops_generate/gen_constants.py +0 -190
  488. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  489. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  490. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  492. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
  493. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -17,6 +17,7 @@
 """Providing interface methods."""
 from __future__ import absolute_import
 
+import gc
 import types
 import sys
 import os
@@ -24,11 +25,11 @@ import time
 import ast
 import inspect
 import importlib
-import hashlib
 import contextlib
+import json
 from collections import OrderedDict, namedtuple
 from functools import wraps
-import numpy as np
+from typing import Optional, Callable
 import mindspore as ms
 from mindspore import context
 from mindspore import log as logger
@@ -39,21 +40,23 @@ from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
 from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
 from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
 from mindspore._c_expression.amp import get_curr_amp_strategy
-from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
+from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
-    _ms_memory_recycle, _bind_device_ctx, StubNode
+    _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, StubNode, MSContext, TensorPy as Tensor
 from mindspore.parallel._ps_context import _is_role_sched
-from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
-    _is_in_auto_parallel_mode, _is_parallel_mode
+from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
+    _is_parallel_mode
 from mindspore import _checkparam as Validator
 from mindspore._checkparam import is_stub_tensor
 from mindspore.common._utils import is_shape_unknown
 from mindspore.common.mutable import mutable, _check_element_type
-from mindspore.common._register_for_adapter import ms_adapter_registry
 from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
     get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
 from mindspore.common._pijit_context import PIJitCaptureContext
 from mindspore.common.parameter import Parameter, set_parameter_hook_updated, parameter_hook_updated
+from mindspore.common.jit_context import jit_context
+from mindspore.common.jit_trace import _jit_trace
+from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
 
 # Store ms_function class compiled pipeline cache.
 ms_compile_cache = set()
@@ -107,8 +110,7 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
             logger.info(f"The {echo_function_name} has been compiled again. "
                         f"{tips} ")
         else:
-            tips = "Try to decorate the function with @jit(hash_args=...) " \
-                   "or @jit(compile_once=True) to reduce the compile time. " \
+            tips = "Try to reuse the function object decorated by @jit to reduce the compile time. " \
                    "For more details, get instructions about `jit` at " \
                    "https://www.mindspore.cn/search?inputValue=jit."
             logger.warning(f"The {echo_function_name} has been compiled again. "
@@ -120,14 +122,6 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
         function_phases[full_function_name].add(create_time)
 
 
-def _ms_adapter_tensor_as_parameter_output(data):
-    """Check whether the data is an output from a parameter which is a ms_adapter tensor.
-    Pylint: disable=unidiomatic-typecheck.
-    """
-    return ms_adapter_registry.is_registered and isinstance(data, ms_adapter_registry.tensor) \
-        and hasattr(data, "__ms_parameter_output__") and getattr(data, "__ms_parameter_output__")
-
-
 def _convert_python_data(data):
     """
     Convert C++ data to python.
@@ -138,18 +132,8 @@ def _convert_python_data(data):
     Returns:
         data, a data convert C++ to python
     """
-    if isinstance(data, (Tensor, PythonTensor)) and data.adapter_flag:
-        return ms_adapter_registry.tensor(data)
-    if _ms_adapter_tensor_as_parameter_output(data) and hasattr(data, "tensor"):
-        return data.tensor
-    if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
-        return PythonTensor(data, internal=True)
-    if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
-        return PythonCSRTensor(csr_tensor=data)
-    if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
-        return PythonCOOTensor(coo_tensor=data)
-    if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
-        return PythonRowTensor(row_tensor=data)
+    if isinstance(data, PythonTensor):
+        return data
     if isinstance(data, StubNode):
         return ms.common._stub_tensor._convert_stub(data)
     if data.__class__ is tuple:
@@ -160,6 +144,12 @@
             fields = data_dict.keys()
             return namedtuple(type_name, fields)(**_convert_python_data(data_dict))
         return tuple(_convert_python_data(x) for x in data)
+    if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
+        return PythonCSRTensor(csr_tensor=data)
+    if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
+        return PythonCOOTensor(coo_tensor=data)
+    if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
+        return PythonRowTensor(row_tensor=data)
     if data.__class__ is list:
         # Keep list object not change for inplace operation.
         for i in range(len(data)):
@@ -578,11 +568,11 @@ def _get_hook_key(*args, **kwargs):
     return hook_key
 
 
-class _MindsporeFunctionExecutor:
+class _JitExecutor:
     """
     Represents a function compiled by graph compiler.
 
-    _MindsporeFunctionExecutor will compile the original function for every combination
+    _JitExecutor will compile the original function for every combination
     of argument types and shapes it is given (as well as their values, optionally).
 
     Args:
@@ -596,7 +586,7 @@ class _MindsporeFunctionExecutor:
         The result of pipeline running in graph mode.
     """
 
-    def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None):
+    def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0):
         init_pipeline()
         if not isinstance(fn, (types.FunctionType, types.MethodType)):
             raise RuntimeError('fn {} is not function or method'.format(fn))
@@ -608,13 +598,61 @@ class _MindsporeFunctionExecutor:
         self.obj = obj
         self.shard_parent_obj = obj
         self.enable_tuple_broaden = False
-        self._graph_executor = GraphExecutor_.get_instance()
+        if _run_jit_pipeline():
+            self._graph_executor = JitExecutor_.get_instance()
+        else:
+            self._graph_executor = GraphExecutor_.get_instance()
         self._create_time = ms_create_time
         self._compile_args = None
+        self._enable_auto_dynamic = dynamic == 1
         self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
 
+    def _predict(self, *args, **kwargs):
+        """Dedicated routine for predict."""
+        if not hasattr(self.obj, "phase"):
+            return False, None
+
+        predict_vailid_phase = {"prefill", 'increment'}
+        predict_phase = self.obj.phase
+        if predict_phase not in predict_vailid_phase:
+            return False, None
+
+        args_list = args
+        if self.obj is not None:
+            args_list = args_list[1:]
+
+        if predict_phase not in self.obj.phase_cache:
+            try:
+                predict_phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+            except Exception as err:
+                _pynative_executor.clear_res()
+                raise err
+        else:  # get compiled args to generate run args by _generate_run_args
+            compile_args = self._generate_compile_args(args_list)
+            key_id = self._get_key_id()
+            compile_args = get_auto_dynamic_shape_args_with_check_input_signature(
+                compile_args,
+                key_id,
+                self.input_signature,
+                self._enable_auto_dynamic
+            )
+            self._compile_args = compile_args
+
+        new_inputs = self._generate_run_args(args_list, kwargs)
+        output = self._graph_executor(
+            tuple(new_inputs),
+            self.obj.phase_cache[self.obj.phase]
+        )
+        res = _convert_python_data(output)
+        return True, res
+
     @_wrap_func
     def __call__(self, *args, **kwargs):
+        predict, res = self._predict(*args, **kwargs)
+        if predict:
+            return res
+        if jit_context() and jit_context().is_nested():
+            return jit_context().run_graph("", None, *())
         args_list = args
         if self.obj is not None:
             args_list = args_list[1:]
@@ -634,10 +672,14 @@ class _MindsporeFunctionExecutor:
             return None
 
         new_inputs = self._generate_run_args(args_list, kwargs)
-        if context.get_context("mode") == context.PYNATIVE_MODE:
+        if context.get_context("mode") == context.PYNATIVE_MODE and not jit_context():
            output = _pynative_executor.grad_jit(*new_inputs)
        else:
            output = self._graph_executor(tuple(new_inputs), phase)
+        if jit_context():
+            if is_stub_tensor(output):
+                output = output.stub_sync()
+            return jit_context().run_graph(phase, output, *tuple(new_inputs))
 
         return output
 
@@ -653,7 +695,8 @@ class _MindsporeFunctionExecutor:
         compile_args = self._generate_compile_args(args)
         key_id = self._get_key_id()
         compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
-                                                                               self.input_signature)
+                                                                               self.input_signature,
+                                                                               self._enable_auto_dynamic)
 
         # Add mutable for compile_args for two scene:
         # 1) Origin args is mutable.
@@ -673,7 +716,7 @@ class _MindsporeFunctionExecutor:
                                 f'`{self.fn.__module__}`')
             self.obj.__parse_method__ = method_name
             if isinstance(self.obj, ms.nn.Cell):
-                generate_name = generate_name + '.' + str(self.obj.create_time)
+                generate_name = generate_name + '.' + str(self.obj.create_time) + self.obj.phase
                 create_time = str(self.obj.create_time)
             else:
                 generate_name = generate_name + '.' + str(self._create_time)
@@ -704,7 +747,7 @@ class _MindsporeFunctionExecutor:
 
         update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
 
-        if phase in ms_compile_cache and not parameter_hook_updated():
+        if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
             # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
             # generated in generate_arguments_key.
             self._graph_executor.clear_compile_arguments_resource()
@@ -726,7 +769,7 @@ class _MindsporeFunctionExecutor:
                 setattr(self.fn.__func__, "__jit_function__", True)
             else:
                 setattr(self.fn, "__jit_function__", True)
-            is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
+            is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase)
             if isinstance(self.fn, types.MethodType):
                 delattr(self.fn.__func__, "__jit_function__")
             else:
@@ -734,12 +777,14 @@ class _MindsporeFunctionExecutor:
         else:
             if isinstance(self.obj, ms.nn.Cell):
                 self._graph_executor.set_weights_values(self.obj.parameters_dict())
-            is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase, True)
+            is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase)
 
         if not is_compile:
             raise RuntimeError("Executor compile failed.")
         set_parameter_hook_updated(False)
         ms_compile_cache.add(phase)
+        if hasattr(self.obj, "phase"):
+            self.obj.phase_cache[self.obj.phase] = phase
 
         return phase
 
@@ -760,7 +805,7 @@ class _MindsporeFunctionExecutor:
         else:
             key_id = str(id(self.obj)) + str(self._create_time)
 
-        if _pynative_executor.grad_flag():
+        if _pynative_executor.requires_grad():
             key_id = key_id + ".grad"
         return key_id
 
@@ -770,9 +815,9 @@ class _MindsporeFunctionExecutor:
                                       self.fn.__code__.co_firstlineno)
         echo_function_name = "function \"" + self.fn.__name__ + "\" at the file \"" + self.fn.__code__.co_filename \
                              + "\", line " + str(self.fn.__code__.co_firstlineno)
-        if _pynative_executor.grad_flag():
+        if _pynative_executor.requires_grad():
             generate_name = generate_name + ".grad"
-        if _is_pynative_parallel():
+        if self.fn.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
             generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
         return generate_name, echo_function_name
 
@@ -833,6 +878,14 @@ class _MindsporeFunctionExecutor:
         """
         return _get_args_for_run(self, args_list, kwargs, self._compile_args)
 
+    def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
+        """Get graph proto from pipeline."""
+        if use_prefix:
+            exec_id = exec_id + '.' + obj.arguments_key
+        if self._graph_executor.has_compiled(exec_id) is False:
+            return None
+        return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
+
 
 # The attributes used to identify a given object.
 attr_op = {"__str__": lambda x: x.__str__(),
@@ -845,6 +898,13 @@ attr_op = {"__str__": lambda x: x.__str__(),
           }
 
 
+def _is_inner_func(func):
+    """Check whether the func is an inner func which needs hash_args parameter."""
+    # This is a workaround for inner api, should fix it later.
+    inner_func = ["after_shard", "_wrap_container"]
+    return func.__name__ in inner_func
+
+
 def _get_obj_id(input_obj):
     """Get hash id of single object."""
     obj_id = ".".join(
@@ -859,50 +919,227 @@ def _get_jit_hash(hash_input):
859
919
  return _get_obj_id(hash_input)
860
920
 
861
921
 
862
- def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=None, compile_once=False):
922
+ def _get_hash_obj(options):
923
+ hash_obj = None
924
+ if "hash_args" in options:
925
+ hash_obj = _get_jit_hash(options["hash_args"])
926
+ del options["hash_args"]
927
+ return hash_obj
928
+
929
+
930
+ def _check_option_device(option, device):
931
+ """Check jit options wiwh device"""
932
+ option_device_cfgs = {
933
+ 'disable_format_transform': ['GPU'],
934
+ 'exec_order': ['Ascend'],
935
+ 'ge_options': ['Ascend'],
936
+ 'infer_boost': ['Ascend'],
937
+ }
938
+ if option in option_device_cfgs and device not in option_device_cfgs[option]:
939
+ logger.warning(f"For 'jit(options)', the option '{option}' is only support device in "
940
+ f"'{option_device_cfgs[option]}', but got '{device}', ignore it.")
941
+
942
+
943
+ def _check_option_backend(option, backend):
944
+ """Check jit options wiwh backend"""
945
+ option_backend_cfgs = {
946
+ 'disable_format_transform': ['ms_backend'],
947
+ 'exec_order': ['ms_backend'],
948
+ 'ge_options': ['GE'],
949
+ 'infer_boost': ['ms_backend'],
950
+ }
951
+ if option in option_backend_cfgs and backend not in option_backend_cfgs[option]:
952
+ logger.warning(f"For 'jit(options)', the option '{option}' is only support backend in "
953
+ f"'{option_backend_cfgs[option]}', but got '{backend}', ignore it.")
954
+
955
+
956
+ def _check_disable_format_transform_value(option, disable_format_transform):
957
+ """check disable_format_transform option value"""
958
+ if not isinstance(disable_format_transform, bool):
959
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be bool, "
960
+ f"but got {type(disable_format_transform)}.")
961
+
962
+
963
+ def _check_exec_order_value(option, exec_order):
964
+ """check exec_order option value"""
965
+ if not isinstance(exec_order, str):
966
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(exec_order)}.")
967
+
968
+ if exec_order not in ['bfs', 'dfs']:
969
+ raise ValueError(f"For '{option}', the value of '{option}' must be one of "
970
+ f"['bfs', 'dfs'], but got '{exec_order}'.")
971
+
972
+
973
+ def _check_ge_options_value(option, ge_options):
974
+ """check ge_options option value"""
975
+ if not isinstance(ge_options, dict):
976
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be dict, but got {type(ge_options)}.")
977
+
978
+ for level, options in ge_options.items():
979
+ if level not in ['global', 'session']:
980
+ raise ValueError(f"For '{option}', the key of '{option}' must be one of "
981
+ f"['global', 'session'], but got '{level}'.")
982
+
983
+ if not isinstance(options, dict):
984
+ raise TypeError(f"For '{option}', the type of {level} options must be dict, "
985
+ f"but got {type(options)}. The error options: {options}.")
986
+
987
+ for key, value in options.items():
988
+ if not isinstance(key, str):
989
+ raise TypeError(f"For '{option}', the type of key and value must be str, "
990
+ f"but got {type(key)}. The error key is {key}.")
991
+ if not isinstance(value, str):
992
+ raise TypeError(f"For '{option}', the type of key and value must be str, "
993
+ f"but got {type(value)}. The error value is {value}")
994
+
995
+
996
+ def _check_infer_boost_value(option, value):
997
+ """check infer_boost option value"""
998
+ if not isinstance(value, str):
999
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(value)}.")
1000
+
1001
+ if value not in ['on', 'off']:
1002
+ raise ValueError(f"For '{option}', the value of '{option}' must be one of ['on', 'off'], but got '{value}'.")
1003
+
1004
+
1005
+ def _check_option_value(option, value):
1006
+ """check jit options wiwh value"""
1007
+ option_valuecheck_funcs = {
1008
+ 'disable_format_transform': _check_disable_format_transform_value,
1009
+ 'exec_order': _check_exec_order_value,
1010
+ 'ge_options': _check_ge_options_value,
1011
+ 'infer_boost': _check_infer_boost_value,
1012
+ }
1013
+ if option in option_valuecheck_funcs:
1014
+ option_valuecheck_funcs[option](option, value)
1015
+ else:
1016
+ logger.warning(f"For 'jit(options)', the option argument '{option}' is not recognized, please check!"
1017
+ f"For detailed usage of 'jit(options)', please refer to the Mindspore official website.")
1018
+
1019
+
1020
+ def _check_options(options, backend):
1021
+ """Check jit options"""
1022
+ # check whether there are deprecated parameters in the dict `options`.
1023
+ deprecated_args = {'mode': 'capture_mode', 'input_signature': 'dynamic', 'hash_args: ': '',
1024
+ 'jit_config': 'jit_level, fullgraph or options', 'compile_once': ''}
1025
+ for key, value in deprecated_args.items():
1026
+ if key in options:
1027
+ log = f"For 'jit', the parameter '{key}' has been deprecated."
1028
+ if value != '':
1029
+ log += f" Please use the parameter '{value}' instead. For more details, please refer to " \
1030
+ f"https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html."
1031
+ logger.warning(log)
1032
+ del options[key]
1033
+
1034
+ # check options' device, backend and value
1035
+ for option, value in options.items():
1036
+ _check_option_backend(option, backend)
1037
+ _check_option_value(option, value)
1038
+
1039
+
1040
+ def jit(
1041
+ function: Optional[Callable] = None,
1042
+ *,
1043
+ capture_mode: str = "ast",
1044
+ jit_level: str = "O0",
1045
+ dynamic: int = 0,
1046
+ fullgraph: bool = False,
1047
+ backend: str = "",
1048
+ **options):
863
1049
  """
864
1050
  Create a callable MindSpore graph from a Python function.
865
1051
 
866
1052
  This allows the MindSpore runtime to apply optimizations based on graph.
867
1053
 
868
1054
  Note:
869
- - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
870
- will not accept `**kwargs`.
871
- - It is not supported to run a function with decoration @jit(mode=“PIJit”)
872
- in static graph mode, in which case the decoration @jit(mode=“PIJit”) is considered invalid.
873
- - Calls to functions with decorated @jit(mode=“PIJit”) inside functions
874
- decorated with @jit(mode=“PIJit”) are not supported,
875
- and the decoration @jit(mode=“PIJit”) is considered invalid.
1055
+ - It is not supported to run a function with decoration @jit(capture_mode=“bytecode”)
1056
+ in static graph mode, in which case the decoration @jit(capture_mode=“bytecode”) is considered invalid.
1057
+ - Calls to functions with decorated @jit(capture_mode=“bytecode”) inside functions
1058
+ decorated with @jit(capture_mode=“ast”) are not supported,
1059
+ and the decoration @jit(capture_mode=“bytecode”) is considered invalid.
876
1060
 
877
1061
  Args:
878
- fn (Function): The Python function that will be run as a graph. Default: ``None`` .
879
- mode (str): The type of jit used, the value of mode should be ``PIJit`` or ``PSJit``. Default: ``PSJit`` .
880
-
881
- - `PSJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html>`_ :
882
- Parse python ast to build graph.
883
- - `PIJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/pynative.html#pijit>`_ :
884
- Parse python bytecode to build graph at runtime.
885
-
886
- input_signature (Union[Tuple, List, Dict, Tensor]): The Tensor which describes the input arguments. The
887
- shape and dtype of the Tensor will be supplied to this function. If `input_signature` is specified, the
888
- input parameters of `fn` cannot accept `**kwargs`, and the shape and dtype of actual inputs should keep the
889
- same as `input_signature`. Otherwise, TypeError will be raised. There are two mode for `input_signature`:
890
-
891
- - Full mode: Arguments is a Tuple, List or a Tensor, and they will be used as all compile inputs
892
- for graph-compiling.
893
- - Incremental mode: Argument is a Dict, and they will set to some of the graph inputs, which will be
894
- substituted into the input at the corresponding position for graph-compiling.
895
-
896
- Default: ``None`` .
897
-
898
- hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
899
- like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
900
- will trigger recompilation. Default: ``None`` .
901
- jit_config (JitConfig): Jit config for compile. Default: ``None`` .
902
- compile_once(bool): ``True``: The function would be compiled once when it was created many times.
903
- But it may be wrong if the free variables were changed. ``False`` : It would be recompiled when
904
- it was created again.
905
- Default: ``False`` .
1062
+ function (Function, optional): The Python function that will be run as a graph. Default: ``None``.
1063
+
1064
+ Keyword Args:
1065
+ capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
1066
+ should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
1067
+
1068
+ - `ast <https://www.mindspore.cn/docs/en/r2.5.0/model_train/program_form/static_graph.html>`_ :
1069
+ Parse Python ast to build graph.
1070
+ - `bytecode <https://www.mindspore.cn/docs/en/r2.5.0/model_train/program_form/pynative.html#pijit>`_ :
1071
+ Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
1072
+ change and/or deletion.
1073
+ - `trace` : Trace the execution of Python code to build graph. This is an experimental prototype that is
1074
+ subject to change and/or deletion.
1075
+
1076
+ jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
1077
+ with default backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
1078
+
1079
+ - `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
1080
+ - `O1`: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
1081
+ level is experimental and is being improved.
1082
+
1083
+ dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
1084
+ is as follows:
1085
+
1086
+ - `0`: Do not perform dynamic shape compilation.
1087
+ - `1`: Enable dynamic shape compilation and automatically detect shape changes.
1088
+
1089
+ fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
1090
+ be compatible with all Python syntax in the function as much as possible. If True, we require that the
1091
+ entire function can be captured into graph. If this is not possible (that is, if there is Python syntax
1092
+ not supported), then it will raise an exception. This currently only applies when capture_mode is ast.
1093
+ Default: ``False``.
1094
+ backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
1095
+ use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including Atlas
1096
+ A2 training series products by default.
1097
+
1098
+ - `ms_backend`: Adopt KernelByKernel execution mode.
1099
+ - `GE`: Adopt Sink execution mode. The whole model will be sinked to device to execute, only applicable to
1100
+ the top cell of model. And only can be used in Ascend platform.
1101
+
1102
+ **options (dict): A dictionary of options to pass to the compilation backend.
1103
+
1104
+ Some options are device specific, see the below table for details:
1105
+
1106
+ +---------------------------+---------------------------+-------------------------+
1107
+ | Option Parameters | Hardware Platform Support | Backend Support |
1108
+ +===========================+===========================+=========================+
1109
+ | disable_format_transform | GPU | ms_backend |
1110
+ +---------------------------+---------------------------+-------------------------+
1111
+ | exec_order | Ascend | ms_backend |
1112
+ +---------------------------+---------------------------+-------------------------+
1113
+ | ge_options | Ascend | GE |
1114
+ +---------------------------+---------------------------+-------------------------+
1115
+ | infer_boost | Ascend | ms_backend |
1116
+ +---------------------------+---------------------------+-------------------------+
1117
+
1118
+ - disable_format_transform (bool, optional): Whether to disable the automatic format transform function
1119
+ from NCHW to NHWC. When the network training performance of fp16 is worse than fp32,
1120
+ `disable_format_transform` can be set to ``True`` to try to improve training performance.
1121
+ Default: ``False`` .
1122
+ - exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
1123
+ methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
1124
+
1125
+ - `bfs`: The default sorting method, breadth priority, good communication masking, relatively good
1126
+ performance.
1127
+ - `dfs`: An optional sorting method, depth-first sorting. The performance is relatively worse than that
1128
+ of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
1129
+ other execution orders run out of memory (OOM).
1130
+
1131
+ - ge_options (dict): Set options for ge backend. The options are divided into two categories: global,
1132
+ and session. This is an experimental prototype that is subject to change and/or deletion.
1133
+ For detailed information, please refer to `Ascend community <https://www.hiascend.com/document/detail/zh/canncommercial/80RC3/apiref/ascendgraphapi/atlasgeapi_07_0146.html>`_ .
1134
+
1135
+ - global (dict): Set global options.
1136
+ - session (dict): Set session options.
1137
+
1138
+ - infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
1139
+ the inference mode is disabled. The range is as follows:
1140
+
1141
+ - `on`: Enable inference mode to get better inference performance.
1142
+ - `off`: Disable inference mode and use the ordinary forward pass for inference; performance is poorer.
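
A minimal usage sketch combining the parameters described above. The keyword names (capture_mode, jit_level, dynamic, fullgraph, backend) and the pass-through options (exec_order, infer_boost) are taken from this signature; whether a given option is accepted depends on the platform/backend pairs in the table, so treat the combination below as illustrative only:

    import numpy as np
    from mindspore import Tensor, jit

    # Hypothetical combination of the new jit keyword arguments; option support
    # depends on the hardware platform and backend listed in the table above.
    @jit(capture_mode="ast",      # build the graph from the Python AST
         jit_level="O1",          # optimization level, "O0" or "O1"
         dynamic=1,               # detect shape changes and recompile dynamically
         fullgraph=False,         # fall back to Python for unsupported syntax
         backend="ms_backend",    # KernelByKernel execution mode
         exec_order="dfs",        # operator ordering, passed through **options
         infer_boost="off")       # inference mode left disabled here
    def scaled_add(x, y):
        return 2 * x + y

    out = scaled_add(Tensor(np.ones((2, 3), np.float32)),
                     Tensor(np.ones((2, 3), np.float32)))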
906
1143
 
907
1144
  Returns:
908
1145
  Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
@@ -921,12 +1158,12 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
921
1158
  >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
922
1159
  >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
923
1160
  ...
924
- >>> # create a callable MindSpore graph by calling decorator @jit
1161
+ >>> # create a callable MindSpore graph by calling jit
925
1162
  >>> def tensor_add(x, y):
926
1163
  ... z = x + y
927
1164
  ... return z
928
1165
  ...
929
- >>> tensor_add_graph = jit(fn=tensor_add)
1166
+ >>> tensor_add_graph = jit(function=tensor_add)
930
1167
  >>> out = tensor_add_graph(x, y)
931
1168
  ...
932
1169
  >>> # create a callable MindSpore graph through decorator @jit
@@ -937,180 +1174,70 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
937
1174
  ...
938
1175
  >>> out = tensor_add_with_dec(x, y)
939
1176
  ...
940
- >>> # create a callable MindSpore graph through decorator @jit with input_signature parameter
941
- >>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
942
- ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
943
- ... def tensor_add_with_sig(x, y):
1177
+ >>> # create a callable MindSpore graph and capture the entire function into the graph
1178
+ >>> @jit(fullgraph=True)
1179
+ ... def tensor_add_fullgraph(x, y):
944
1180
  ... z = x + y
945
1181
  ... return z
946
1182
  ...
947
- >>> out = tensor_add_with_sig(x, y)
948
- ...
949
- >>> @jit(input_signature={"y": Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))})
950
- ... def tensor_add_with_sig_1(x, y):
951
- ... z = x + y
952
- ... return z
953
- ...
954
- >>> out1 = tensor_add_with_sig_1(x, y)
955
- ...
956
- ... # Set hash_args as fn, otherwise cache of compiled closure_fn will not be reused.
957
- ... # While fn differs during calling again, recompilation will be triggered.
958
- >>> def func(x):
959
- ... return ops.exp(x)
960
- ...
961
- >>> def closure_fn(x, fn):
962
- ... @jit(hash_args=fn)
963
- ... def inner_fn(a):
964
- ... return fn(a)
965
- ... return inner_fn(x)
966
- ...
967
- >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
968
- >>> for i in range(10):
969
- ... closure_fn(inputs, func)
970
- ...
971
- ... # Set compile_once = True, otherwise the train_step will be compiled again.
972
- >>> def train(x):
973
- ... @jit(compile_once = True)
974
- ... def train_step(x):
975
- ... return ops.exp(x)
976
- ... for i in range(10):
977
- ... train_step(x)
978
- ...
979
- >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
980
- >>> for i in range(10):
981
- ... train(inputs)
1183
+ >>> out = tensor_add_fullgraph(x, y)
982
1184
  """
983
1185
 
984
- def wrap_mindspore(func):
985
- if not isinstance(compile_once, bool):
986
- logger.warning(f"The parameter `compile_once` of jit should be a bool, "
987
- f"but got {type(compile_once)}.")
988
- if hash_args:
989
- hash_obj = _get_jit_hash(hash_args)
990
- elif compile_once:
991
- hash_obj = 0
992
- else:
1186
+ capture_mode = Validator.check_string(capture_mode, ["ast", "bytecode", "trace"], "capture_mode", "jit")
1187
+ jit_level = Validator.check_string(jit_level, ["O0", "O1"], "jit_level", "jit")
1188
+ dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
1189
+ fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
1190
+ if backend == "":
1191
+ backend = "GE" if MSContext.get_instance().get_ascend_soc_version() == "ascend910" else "ms_backend"
1192
+ backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
1193
+ jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
1194
+ hash_obj = _get_hash_obj(options)
1195
+ _check_options(options, backend)
1196
+ options_str = json.dumps(options)
1197
+ infer_boost = options['infer_boost'] if 'infer_boost' in options else "off"
1198
+ exc_mode = options['exc_mode'] if 'exc_mode' in options else "auto"
1199
+ jit_config = JitConfig(jit_level=jit_level, exc_mode=exc_mode, jit_syntax_level=jit_syntax_level,
1200
+ infer_boost=infer_boost, backend=backend, options=options_str)
1201
+
1202
+ def wrap_func(func):
1203
+ nonlocal hash_obj
1204
+ if hash_obj is None or not _is_inner_func(func):
993
1205
  hash_obj = int(time.time() * 1e9)
994
1206
 
995
- dyn_args = _process_dyn_args(func, input_signature)
996
-
997
1207
  @wraps(func)
998
1208
  def staging_specialize(*args, **kwargs):
999
1209
  if os.getenv("MS_JIT") == '0':
1000
1210
  return func(*args, **kwargs)
1001
1211
 
1002
1212
  args, kwargs = _handle_func_args(func, *args, **kwargs)
1003
-
1004
1213
  process_obj = None
1005
1214
  if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
1006
1215
  process_obj = args[0]
1007
- # only the function or cell instance wrapped by shard will fall into this branch
1008
- if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
1009
- process_obj = hash_args
1010
1216
  # Handle auto mixed precision strategy.
1011
1217
  if not hasattr(func, "amp_strategy"):
1012
1218
  if isinstance(func, types.MethodType):
1013
1219
  setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
1014
1220
  else:
1015
1221
  setattr(func, "amp_strategy", get_curr_amp_strategy())
1016
- out = _MindsporeFunctionExecutor(func, hash_obj, dyn_args, process_obj, jit_config)(*args, **kwargs)
1222
+
1223
+ ms_function_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic)
1224
+ out = ms_function_executor(*args, **kwargs)
1017
1225
  return out
1018
1226
 
1019
1227
  return staging_specialize
1020
1228
 
1021
- wrap_func = wrap_mindspore
1022
- if mode == "PIJit":
1023
- wrap_func = PIJitCaptureContext(jit_config, input_signature)
1229
+ if capture_mode == "bytecode":
1230
+ wrap_func = PIJitCaptureContext(jit_config)
1231
+ elif capture_mode == "trace":
1232
+ if function is not None:
1233
+ return _jit_trace(function)
1234
+ return _jit_trace
1024
1235
 
1025
- if fn is not None:
1026
- return wrap_func(fn)
1236
+ if function is not None:
1237
+ return wrap_func(function)
1027
1238
  return wrap_func
1028
1239
 
1029
1240
 
1030
- def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
1031
- """
1032
- Create a callable MindSpore graph from a Python function.
1033
-
1034
- This allows the MindSpore runtime to apply optimizations based on graph.
1035
-
1036
- Note:
1037
- - `ms_function` will be deprecated and removed in a future version. Please use :func:`mindspore.jit` instead.
1038
- - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
1039
- will not accept `**kwargs`.
1040
-
1041
- Args:
1042
- fn (Function): The Python function that will be run as a graph. Default: ``None`` .
1043
- input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
1044
- will be supplied to this function. The shape and dtype of actual inputs of `fn` should
1045
- keep the same as input_signature. Otherwise, TypeError will be raised. Default: ``None`` .
1046
- hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
1047
- like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
1048
- will trigger recompilation. Default: ``None`` .
1049
- jit_config (JitConfig): Jit config for compile. Default: ``None`` .
1050
-
1051
- Returns:
1052
- Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
1053
- None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
1054
- equal to the case when `fn` is not None.
1055
-
1056
- Supported Platforms:
1057
- ``Ascend`` ``GPU`` ``CPU``
1058
-
1059
- Examples:
1060
- >>> import numpy as np
1061
- >>> from mindspore import Tensor
1062
- >>> from mindspore import ops
1063
- >>> from mindspore import ms_function
1064
- ...
1065
- >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
1066
- >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
1067
- ...
1068
- >>> # create a callable MindSpore graph by calling ms_function
1069
- >>> def tensor_add(x, y):
1070
- ... z = x + y
1071
- ... return z
1072
- ...
1073
- >>> tensor_add_graph = ms_function(fn=tensor_add)
1074
- >>> out = tensor_add_graph(x, y)
1075
- ...
1076
- >>> # create a callable MindSpore graph through decorator @ms_function
1077
- >>> @ms_function
1078
- ... def tensor_add_with_dec(x, y):
1079
- ... z = x + y
1080
- ... return z
1081
- ...
1082
- >>> out = tensor_add_with_dec(x, y)
1083
- ...
1084
- >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
1085
- >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
1086
- ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
1087
- ... def tensor_add_with_sig(x, y):
1088
- ... z = x + y
1089
- ... return z
1090
- ...
1091
- >>> out = tensor_add_with_sig(x, y)
1092
- ...
1093
- ... # Set hash_args as fn, otherwise cache of compiled `closure_fn` will not be reused.
1094
- ... # While fn differs during calling again, recompilation will be triggered.
1095
- >>> def func(x):
1096
- ... return ops.exp(x)
1097
- ...
1098
- >>> def closure_fn(x, fn):
1099
- ... @ms_function(hash_args=fn)
1100
- ... def inner_fn(a):
1101
- ... return fn(a)
1102
- ... return inner_fn(x)
1103
- ...
1104
- >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
1105
- >>> for i in range(10):
1106
- ... closure_fn(inputs, func)
1107
- """
1108
-
1109
- logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
1110
- "Please use 'mindspore.jit' instead.")
1111
- return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)
1112
-
1113
-
1114
1241
  def _core(fn=None, **flags):
1115
1242
  """
1116
1243
  A decorator that adds a flag to the function.
@@ -1203,69 +1330,6 @@ def _no_recursive(callable_obj):
1203
1330
  return callable_obj
1204
1331
 
1205
1332
 
1206
- def ms_class(cls):
1207
- """
1208
- Class decorator for user-defined classes.
1209
-
1210
- This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
1211
-
1212
- Note:
1213
- `ms_class` will be deprecated and removed in a future version. Please use :func:`mindspore.jit_class` instead.
1214
-
1215
- Args:
1216
- cls (Class): User-defined class.
1217
-
1218
- Returns:
1219
- Class.
1220
-
1221
- Raises:
1222
- TypeError: If ms_class is used for non-class types or nn.Cell.
1223
- AttributeError: If the private attributes or magic methods of the class decorated with ms_class is called.
1224
-
1225
- Supported Platforms:
1226
- ``Ascend`` ``GPU`` ``CPU``
1227
-
1228
- Examples:
1229
- >>> import mindspore.nn as nn
1230
- >>> from mindspore import ms_class
1231
- ...
1232
- >>> @ms_class
1233
- ... class UserDefinedNet:
1234
- ... def __init__(self):
1235
- ... self.value = 10
1236
- ...
1237
- ... def func(self, x):
1238
- ... return 2 * x
1239
- ...
1240
- >>> class Net(nn.Cell):
1241
- ... def __init__(self):
1242
- ... super(Net, self).__init__()
1243
- ... self.net = UserDefinedNet()
1244
- ...
1245
- ... def construct(self, x):
1246
- ... out = self.net.value + self.net.func(x)
1247
- ... return out
1248
- ...
1249
- >>> net = Net()
1250
- >>> out = net(5)
1251
- >>> print(out)
1252
- 20
1253
- """
1254
-
1255
- logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
1256
- "Please use 'mindspore.jit_class' instead.")
1257
-
1258
- # Check if cls is of type class.
1259
- if not inspect.isclass(cls):
1260
- raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
1261
- # Check if cls is nn.Cell.
1262
- if issubclass(cls, ms.nn.Cell):
1263
- raise TypeError(f"Decorator ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
1264
- logger.info(f'Found ms_class: {cls}.')
1265
- setattr(cls, '__ms_class__', True)
1266
- return cls
1267
-
1268
-
1269
1333
  def jit_class(cls):
1270
1334
  """
1271
1335
  Class decorator for user-defined classes.
@@ -1322,28 +1386,6 @@ def jit_class(cls):
1322
1386
  return cls
1323
1387
 
1324
1388
 
1325
- def set_adapter_config(config):
1326
- """
1327
- Register configuration information for MSAdapter.
1328
-
1329
- Args:
1330
- config (dict): Configuration information.
1331
- """
1332
- if not isinstance(config, dict):
1333
- raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
1334
- for key, value in config.items():
1335
- if key == "Tensor":
1336
- ms_adapter_registry.register_tensor(value)
1337
- elif key == "Parameter":
1338
- ms_adapter_registry.register_parameter(value)
1339
- elif key == "convert_object_map":
1340
- ms_adapter_registry.register_convert_map(value)
1341
- elif key == "convert_adapter_tensor_map":
1342
- ms_adapter_registry.register_convert_adapter_tensor_map(value)
1343
- else:
1344
- raise ValueError(f"Unsupported key in adapter config: {key}")
1345
-
1346
-
1347
1389
  def _function_forbid_reuse(func):
1348
1390
  if not inspect.isfunction(func):
1349
1391
  raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
@@ -1535,7 +1577,24 @@ class _PyNativeExecutor:
1535
1577
  Return:
1536
1578
  None.
1537
1579
  """
1538
- return self._executor.grad(grad, obj, weights, grad_position, *args)
1580
+ return self._executor.grad(grad, obj, weights, grad_position, False, *args)
1581
+
1582
+ def grad_aux(self, obj, grad, weights, grad_position, *args):
1583
+ """
1584
+ Run the grad graph with auxiliary outputs.
1585
+
1586
+ Args:
1587
+ obj (Function/Cell): The function or cell instance.
1588
+ grad (GradOperation): The GradOperation object.
1589
+ weights (ParameterTuple): The weights of cell instance.
1590
+ grad_position (Union(int, tuple[int])): If int, get the gradient with respect to a single input.
1591
+ If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
1592
+ args (tuple): Function or cell input arguments.
1593
+
1594
+ Return:
1595
+ None.
1596
+ """
1597
+ return self._executor.grad(grad, obj, weights, grad_position, True, *args)
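
grad_aux differs from grad only in the extra boolean forwarded to the executor. A minimal sketch of the kind of auxiliary-output gradient this path serves, assuming the public value_and_grad API with has_aux=True is what ultimately routes here (that mapping is an assumption, not stated in this diff):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor

    # Forward function returning (loss, auxiliary value); with has_aux=True only
    # the first output is differentiated and the rest are passed through.
    def forward(x):
        loss = (x * x).sum()
        aux = x.mean()
        return loss, aux

    grad_fn = ms.value_and_grad(forward, grad_position=0, has_aux=True)
    (loss, aux), grads = grad_fn(Tensor(np.arange(4, dtype=np.float32)))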
1539
1598
 
1540
1599
  def clear_res(self):
1541
1600
  """
@@ -1671,6 +1730,15 @@ class _PyNativeExecutor:
1671
1730
  """
1672
1731
  self._executor.set_is_run_recompute(status)
1673
1732
 
1733
+ def high_order(self):
1734
+ """
1735
+ Whether the current scene is high order. This is an inner interface.
1736
+
1737
+ Return:
1738
+ Bool.
1739
+ """
1740
+ return self._executor.high_order()
1741
+
1674
1742
  def set_cell_use_dynamic_shape_process(self, flag):
1675
1743
  """
1676
1744
  Set the dynamic shape flag of eval process.
@@ -1753,7 +1821,6 @@ class _CellGraphExecutor:
1753
1821
  # create needed graph by lazy mode
1754
1822
  self.is_init = False
1755
1823
  self.enable_tuple_broaden = False
1756
- self.obfuscate_config = None # used for model's dynamic obfuscation
1757
1824
  self._graph_executor = GraphExecutor_.get_instance()
1758
1825
  self._graph_executor.set_py_exe_path(sys.executable)
1759
1826
  self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
@@ -1845,6 +1912,7 @@ class _CellGraphExecutor:
1845
1912
  Str, the full phase of the cell.
1846
1913
  Bool, if the graph has been compiled before, return False, else return True.
1847
1914
  """
1915
+ _init_auto_parallel_context(obj)
1848
1916
  obj.__parse_method__ = 'construct'
1849
1917
  if not hasattr(obj, obj.__parse_method__):
1850
1918
  raise AttributeError(
@@ -1877,6 +1945,7 @@ class _CellGraphExecutor:
1877
1945
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
1878
1946
  # generated in generate_arguments_key.
1879
1947
  self._graph_executor.clear_compile_arguments_resource()
1948
+ _clear_auto_parallel_context(obj)
1880
1949
  return phase, False
1881
1950
 
1882
1951
  full_function_name = obj.__class__.__name__ + '.' + str(obj.instance_count) + '.' + str(id(type(obj)))
@@ -1894,7 +1963,8 @@ class _CellGraphExecutor:
1894
1963
  else:
1895
1964
  jit_config_dict = JitConfig().jit_config_dict
1896
1965
  self._graph_executor.set_jit_config(jit_config_dict)
1897
- result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
1966
+ gc.collect()
1967
+ result = self._graph_executor.compile(obj, args, kwargs, phase)
1898
1968
  obj.compile_cache.add(phase)
1899
1969
  if not result:
1900
1970
  raise RuntimeError("Executor compile failed.")
@@ -1915,6 +1985,7 @@ class _CellGraphExecutor:
1915
1985
  self._build_data_graph(obj, phase)
1916
1986
  elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
1917
1987
  _parameter_broadcast(obj)
1988
+ _clear_auto_parallel_context(obj)
1918
1989
  return phase, True
1919
1990
 
1920
1991
  def _update_param_node_default_input(self, phase, replace):
@@ -1994,25 +2065,12 @@ class _CellGraphExecutor:
1994
2065
  """Clear the memory resource of a network."""
1995
2066
  self._graph_executor.del_net_res(obj, net_id)
1996
2067
 
1997
- def _get_branch_control_input(self):
1998
- if ('obf_ratio' not in self.obfuscate_config.keys()) or (
1999
- 'obf_random_seed' not in self.obfuscate_config.keys()):
2000
- raise ValueError("'obf_ratio' and 'obf_random_seed' must be in obfuscate_config.")
2001
- obf_random_seed = self.obfuscate_config.get('obf_random_seed')
2002
- if obf_random_seed == 0:
2003
- branch_control_input = 0
2004
- else:
2005
- branch_control_input = _generate_branch_control_input(obf_random_seed)
2006
- return branch_control_input
2007
-
2008
2068
  def _get_func_graph(self, obj, exec_id, use_prefix=False):
2009
2069
  """Get func graph from pipeline."""
2010
2070
  if use_prefix:
2011
2071
  exec_id = exec_id + '.' + obj.arguments_key
2012
2072
  if self._graph_executor.has_compiled(exec_id) is False:
2013
2073
  return None
2014
- if self.obfuscate_config is not None:
2015
- raise ValueError('For get func graph, obfuscate_config is currently not supported now.')
2016
2074
  return self._graph_executor.get_func_graph(exec_id)
2017
2075
 
2018
2076
  def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
@@ -2021,11 +2079,6 @@ class _CellGraphExecutor:
2021
2079
  exec_id = exec_id + '.' + obj.arguments_key
2022
2080
  if self._graph_executor.has_compiled(exec_id) is False:
2023
2081
  return None
2024
- if self.obfuscate_config is not None:
2025
- branch_control_input = self._get_branch_control_input()
2026
- return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, incremental,
2027
- self.obfuscate_config['obf_ratio'],
2028
- branch_control_input)
2029
2082
  return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
2030
2083
 
2031
2084
  def get_optimize_graph_proto(self, obj):
@@ -2063,6 +2116,8 @@ def ms_memory_recycle():
2063
2116
  """
2064
2117
  if ms_compile_cache:
2065
2118
  _cell_graph_executor.del_net_res(None, ms_compile_cache)
2119
+ if os.getenv('MS_DEV_JIT_PIPELINE') != '0':
2120
+ JitExecutor_.get_instance().del_net_res(None, ms_compile_cache)
2066
2121
  ms_compile_cache.clear()
2067
2122
  for cell_cache in cells_compile_cache.values():
2068
2123
  if cell_cache:
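
A minimal sketch of the behaviour added above: any value of MS_DEV_JIT_PIPELINE other than '0' (including unset) lets ms_memory_recycle also release graph resources held by the jit executor; setting it to '0' opts out. The variable name and opt-out value come from this change:

    import os
    import mindspore as ms

    # Optionally skip the extra jit-pipeline release path added in this change.
    os.environ["MS_DEV_JIT_PIPELINE"] = "0"

    # Release memory resources held by cached compiled graphs.
    ms.ms_memory_recycle()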
@@ -2089,30 +2144,6 @@ def set_recursion_limit(recursion_limit=1000):
2089
2144
  GraphExecutor_.get_instance().set_max_call_depth(recursion_limit)
2090
2145
 
2091
2146
 
2092
- def _generate_branch_control_input(obf_random_seed):
2093
- """Generate append network input for dynamic obfuscation in random seed mode."""
2094
- seed_max = 2 ** 32 - 1
2095
- int_max = 2 ** 31 - 1
2096
- np.random.seed(obf_random_seed % seed_max)
2097
- # generate a string as hash function inputs
2098
- word_repo = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghigklmnopqrstuvwxyz" + "0123456789"
2099
- repo_len = len(word_repo)
2100
- sha_string = ''
2101
- string_len = 1024 * 1024
2102
- for _ in range(string_len):
2103
- rand_index = np.random.randint(0, repo_len)
2104
- sha_string += word_repo[rand_index]
2105
- # get hash result
2106
- sha_result = hashlib.sha256(sha_string.encode('utf-8')).hexdigest() # len is 64
2107
- branch_control_input = 1
2108
- hex_base = 16
2109
- for item in sha_result:
2110
- if int(item, hex_base) > 0:
2111
- branch_control_input *= int(item, hex_base)
2112
- branch_control_input %= int_max
2113
- return branch_control_input
2114
-
2115
-
2116
2147
  def _bind_device_context():
2117
2148
  """Bind device context to current thread"""
2118
2149
  _bind_device_ctx()
@@ -2135,4 +2166,4 @@ def flops_collection(phase='train'):
2135
2166
  _cell_graph_executor = _CellGraphExecutor()
2136
2167
  _pynative_executor = _PyNativeExecutor()
2137
2168
 
2138
- __all__ = ['ms_function', 'ms_memory_recycle', 'ms_class', 'jit', 'jit_class', 'flops_collection']
2169
+ __all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']
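
With ms_function and ms_class removed from __all__ and their definitions deleted, callers migrate to the replacements named by the deleted deprecation warnings, jit and jit_class. A minimal migration sketch:

    import numpy as np
    from mindspore import Tensor, jit, jit_class   # ms_function / ms_class are gone

    @jit                      # previously: @ms_function
    def tensor_add(x, y):
        return x + y

    @jit_class                # previously: @ms_class
    class Config:
        def __init__(self):
            self.scale = 2

    out = tensor_add(Tensor(np.ones((2, 2), np.float32)),
                     Tensor(np.ones((2, 2), np.float32)))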