mindspore 2.4.10-cp310-cp310-win_amd64.whl → 2.6.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +47 -198
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +229 -99
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +480 -372
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +5 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +975 -1981
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +324 -573
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +183 -117
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +179 -120
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +798 -761
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +933 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1373 -192
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +53 -42
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +19 -15
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +3 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +361 -359
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +52 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +757 -185
  287. mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
  288. mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4351 -3813
  334. mindspore/ops/function/nn_func.py +1712 -637
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +452 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +26 -18
  343. mindspore/ops/functional.py +23 -7
  344. mindspore/ops/functional_overload.py +1548 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +23 -15
  347. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +298 -87
  360. mindspore/ops/operations/debug_ops.py +157 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +212 -531
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1895 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +159 -40
  426. mindspore/parallel/_cell_wrapper.py +132 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +700 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +258 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -59
  446. mindspore/parallel/transform_safetensors.py +364 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +416 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +96 -27
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +269 -136
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +552 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -17,6 +17,7 @@
17
17
  """Providing interface methods."""
18
18
  from __future__ import absolute_import
19
19
 
20
+ import gc
20
21
  import types
21
22
  import sys
22
23
  import os
@@ -24,11 +25,11 @@ import time
24
25
  import ast
25
26
  import inspect
26
27
  import importlib
27
- import hashlib
28
28
  import contextlib
29
+ import json
29
30
  from collections import OrderedDict, namedtuple
30
31
  from functools import wraps
31
- import numpy as np
32
+ from typing import Optional, Callable
32
33
  import mindspore as ms
33
34
  from mindspore import context
34
35
  from mindspore import log as logger
@@ -39,21 +40,23 @@ from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
39
40
  from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
40
41
  from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
41
42
  from mindspore._c_expression.amp import get_curr_amp_strategy
42
- from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
43
+ from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
43
44
  PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
44
- _ms_memory_recycle, _bind_device_ctx
45
+ _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, StubNode, MSContext, TensorPy as Tensor
45
46
  from mindspore.parallel._ps_context import _is_role_sched
46
- from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
47
- _is_in_auto_parallel_mode, _is_parallel_mode
47
+ from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
48
+ _is_parallel_mode
48
49
  from mindspore import _checkparam as Validator
49
50
  from mindspore._checkparam import is_stub_tensor
50
51
  from mindspore.common._utils import is_shape_unknown
51
- from mindspore.common.mutable import mutable
52
- from mindspore.common._register_for_adapter import ms_adapter_registry
52
+ from mindspore.common.mutable import mutable, _check_element_type
53
53
  from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
54
54
  get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
55
55
  from mindspore.common._pijit_context import PIJitCaptureContext
56
- from mindspore.common.parameter import Parameter
56
+ from mindspore.common.parameter import Parameter, set_parameter_hook_updated, parameter_hook_updated
57
+ from mindspore.common.jit_context import jit_context
58
+ from mindspore.common.jit_trace import _jit_trace
59
+ from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
57
60
 
58
61
  # Store ms_function class compiled pipeline cache.
59
62
  ms_compile_cache = set()
@@ -107,8 +110,7 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
107
110
  logger.info(f"The {echo_function_name} has been compiled again. "
108
111
  f"{tips} ")
109
112
  else:
110
- tips = "Try to decorate the function with @jit(hash_args=...) " \
111
- "or @jit(compile_once=True) to reduce the compile time. " \
113
+ tips = "Try to reuse the function object decorated by @jit to reduce the compile time. " \
112
114
  "For more details, get instructions about `jit` at " \
113
115
  "https://www.mindspore.cn/search?inputValue=jit."
114
116
  logger.warning(f"The {echo_function_name} has been compiled again. "
@@ -120,14 +122,6 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
120
122
  function_phases[full_function_name].add(create_time)
121
123
 
122
124
 
123
- def _ms_adapter_tensor_as_parameter_output(data):
124
- """Check whether the data is an output from a parameter which is a ms_adapter tensor.
125
- Pylint: disable=unidiomatic-typecheck.
126
- """
127
- return ms_adapter_registry.is_registered and isinstance(data, ms_adapter_registry.tensor) \
128
- and hasattr(data, "__ms_parameter_output__") and getattr(data, "__ms_parameter_output__")
129
-
130
-
131
125
  def _convert_python_data(data):
132
126
  """
133
127
  Convert C++ data to python.
@@ -138,18 +132,10 @@ def _convert_python_data(data):
138
132
  Returns:
139
133
  data, a data convert C++ to python
140
134
  """
141
- if isinstance(data, (Tensor, PythonTensor)) and data.adapter_flag:
142
- return ms_adapter_registry.tensor(data)
143
- if _ms_adapter_tensor_as_parameter_output(data) and hasattr(data, "tensor"):
144
- return data.tensor
145
- if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
146
- return PythonTensor(data, internal=True)
147
- if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
148
- return PythonCSRTensor(csr_tensor=data)
149
- if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
150
- return PythonCOOTensor(coo_tensor=data)
151
- if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
152
- return PythonRowTensor(row_tensor=data)
135
+ if isinstance(data, PythonTensor):
136
+ return data
137
+ if isinstance(data, StubNode):
138
+ return ms.common._stub_tensor._convert_stub(data)
153
139
  if data.__class__ is tuple:
154
140
  # Handle namedtuple since its type is tuple.
155
141
  if hasattr(data, "_fields"):
@@ -158,6 +144,12 @@ def _convert_python_data(data):
158
144
  fields = data_dict.keys()
159
145
  return namedtuple(type_name, fields)(**_convert_python_data(data_dict))
160
146
  return tuple(_convert_python_data(x) for x in data)
147
+ if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
148
+ return PythonCSRTensor(csr_tensor=data)
149
+ if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
150
+ return PythonCOOTensor(coo_tensor=data)
151
+ if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
152
+ return PythonRowTensor(row_tensor=data)
161
153
  if data.__class__ is list:
162
154
  # Keep list object not change for inplace operation.
163
155
  for i in range(len(data)):
@@ -273,7 +265,9 @@ def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
273
265
  else:
274
266
  whole_module = module_name
275
267
  if n.name is not None:
276
- whole_module += "." + n.name
268
+ if not whole_module.endswith("."):
269
+ whole_module += "."
270
+ whole_module += n.name
277
271
  try:
278
272
  module_spec = importlib.util.find_spec(whole_module, pkg)
279
273
  except (ModuleNotFoundError, ValueError):
@@ -305,7 +299,22 @@ def _get_compile_cache_dep_files():
305
299
  return compile_cache_dep_files
306
300
 
307
301
 
308
- def _restore_mutable_attr(args_list, compile_args):
302
+ def _contains_auto_grad_tensor(obj):
303
+ """Check object is or contains auto grad tensor element"""
304
+ if isinstance(obj, PythonTensor):
305
+ return obj._has_auto_grad()
306
+ if isinstance(obj, (tuple, list)):
307
+ for element in obj:
308
+ if _contains_auto_grad_tensor(element):
309
+ return True
310
+ if isinstance(obj, dict):
311
+ for key in obj:
312
+ if _contains_auto_grad_tensor(obj[key]):
313
+ return True
314
+ return False
315
+
316
+
317
+ def _add_mutable_attr(args_list, compile_args, is_grad):
309
318
  """Restore the mutable attr for every arg."""
310
319
  new_compile_args = ()
311
320
  for idx, arg in enumerate(args_list):
@@ -316,7 +325,12 @@ def _restore_mutable_attr(args_list, compile_args):
316
325
  else:
317
326
  new_compile_args += (mutable(compile_args[idx], False),)
318
327
  else:
319
- new_compile_args += (compile_args[idx],)
328
+ if is_grad and _contains_auto_grad_tensor(arg):
329
+ if not _check_element_type(arg):
330
+ raise RuntimeError("Input \"%s\" contains tensor with gradient but can not mutable." % (str(arg)))
331
+ new_compile_args += (mutable(compile_args[idx], False),)
332
+ else:
333
+ new_compile_args += (compile_args[idx],)
320
334
  return new_compile_args
321
335
 
322
336
 
@@ -330,6 +344,7 @@ def _get_parameter_layout():
330
344
 
331
345
  def _handle_arg(obj, arg, compile_arg):
332
346
  """Handle arg for runtime .If need handle the arg, return True"""
347
+ from mindspore._extends.parse import compile_config
333
348
  if isinstance(arg, PythonTensor):
334
349
  if arg.has_init:
335
350
  arg.init_data()
@@ -342,7 +357,8 @@ def _handle_arg(obj, arg, compile_arg):
342
357
  if isinstance(arg, list) and not arg:
343
358
  return None
344
359
  return arg
345
- elif context.get_context("grad_for_scalar") and isinstance(arg, (int, float)):
360
+ elif (context.get_context("grad_for_scalar") or str(compile_config.GRAD_FOR_SCALAR) == '1') and \
361
+ isinstance(arg, (int, float)):
346
362
  return arg
347
363
  elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
348
364
  _check_all_tensor(arg):
@@ -528,12 +544,35 @@ def _get_parameter_ids(args, kwargs):
528
544
  parameter_ids += str(id(value))
529
545
  return parameter_ids
530
546
 
547
+ def _get_tensor_hook_key(tensor):
548
+ """Get the hook key of Tensor/Parameter"""
549
+ return ".".join(map(str, map(id, tensor.hooks())))
550
+
551
+ def _get_hook_key(*args, **kwargs):
552
+ """Get the hook key of Tensors/Parameters"""
553
+ hook_key = ""
554
+ for idx, arg in enumerate(args):
555
+ if idx != 0:
556
+ hook_key += "."
557
+ # Only arg of the type Tensor or Parameter is supported now
558
+ if isinstance(arg, (Tensor, Parameter)):
559
+ hook_key += _get_tensor_hook_key(arg)
560
+
561
+ for idx, value in enumerate(kwargs.values()):
562
+ if idx != 0:
563
+ hook_key += "."
564
+ # Only kwarg of the type Tensor or Parameter is supported now
565
+ if isinstance(value, (Tensor, Parameter)):
566
+ hook_key += _get_tensor_hook_key(value)
531
567
 
532
- class _MindsporeFunctionExecutor:
568
+ return hook_key
569
+
570
+
571
+ class _JitExecutor:
533
572
  """
534
573
  Represents a function compiled by graph compiler.
535
574
 
536
- _MindsporeFunctionExecutor will compile the original function for every combination
575
+ _JitExecutor will compile the original function for every combination
537
576
  of argument types and shapes it is given (as well as their values, optionally).
538
577
 
539
578
  Args:
@@ -547,7 +586,7 @@ class _MindsporeFunctionExecutor:
547
586
  The result of pipeline running in graph mode.
548
587
  """
549
588
 
550
- def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None):
589
+ def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0):
551
590
  init_pipeline()
552
591
  if not isinstance(fn, (types.FunctionType, types.MethodType)):
553
592
  raise RuntimeError('fn {} is not function or method'.format(fn))
@@ -559,13 +598,61 @@ class _MindsporeFunctionExecutor:
559
598
  self.obj = obj
560
599
  self.shard_parent_obj = obj
561
600
  self.enable_tuple_broaden = False
562
- self._graph_executor = GraphExecutor_.get_instance()
601
+ if _run_jit_pipeline():
602
+ self._graph_executor = JitExecutor_.get_instance()
603
+ else:
604
+ self._graph_executor = GraphExecutor_.get_instance()
563
605
  self._create_time = ms_create_time
564
606
  self._compile_args = None
607
+ self._enable_auto_dynamic = dynamic == 1
565
608
  self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
566
609
 
610
+ def _predict(self, *args, **kwargs):
611
+ """Dedicated routine for predict."""
612
+ if not hasattr(self.obj, "phase"):
613
+ return False, None
614
+
615
+ predict_vailid_phase = {"prefill", 'increment'}
616
+ predict_phase = self.obj.phase
617
+ if predict_phase not in predict_vailid_phase:
618
+ return False, None
619
+
620
+ args_list = args
621
+ if self.obj is not None:
622
+ args_list = args_list[1:]
623
+
624
+ if predict_phase not in self.obj.phase_cache:
625
+ try:
626
+ predict_phase = self.compile(self.fn.__name__, *args_list, **kwargs)
627
+ except Exception as err:
628
+ _pynative_executor.clear_res()
629
+ raise err
630
+ else: # get compiled args to generate run args by _generate_run_args
631
+ compile_args = self._generate_compile_args(args_list)
632
+ key_id = self._get_key_id()
633
+ compile_args = get_auto_dynamic_shape_args_with_check_input_signature(
634
+ compile_args,
635
+ key_id,
636
+ self.input_signature,
637
+ self._enable_auto_dynamic
638
+ )
639
+ self._compile_args = compile_args
640
+
641
+ new_inputs = self._generate_run_args(args_list, kwargs)
642
+ output = self._graph_executor(
643
+ tuple(new_inputs),
644
+ self.obj.phase_cache[self.obj.phase]
645
+ )
646
+ res = _convert_python_data(output)
647
+ return True, res
648
+
567
649
  @_wrap_func
568
650
  def __call__(self, *args, **kwargs):
651
+ predict, res = self._predict(*args, **kwargs)
652
+ if predict:
653
+ return res
654
+ if jit_context() and jit_context().is_nested():
655
+ return jit_context().run_graph("", None, *())
569
656
  args_list = args
570
657
  if self.obj is not None:
571
658
  args_list = args_list[1:]
@@ -581,13 +668,18 @@ class _MindsporeFunctionExecutor:
581
668
  _pynative_executor.clear_res()
582
669
  raise err
583
670
 
584
- if context.get_context("precompile_only"):
671
+ if context.get_context("precompile_only") or os.getenv('MS_DEV_PRECOMPILE_ONLY') == '1':
585
672
  return None
586
673
 
587
674
  new_inputs = self._generate_run_args(args_list, kwargs)
588
- output = self._graph_executor(tuple(new_inputs), phase)
589
- if context.get_context("mode") == context.PYNATIVE_MODE:
590
- output = _pynative_executor.grad_jit(output, *new_inputs)
675
+ if context.get_context("mode") == context.PYNATIVE_MODE and not jit_context():
676
+ output = _pynative_executor.grad_jit(*new_inputs)
677
+ else:
678
+ output = self._graph_executor(tuple(new_inputs), phase)
679
+ if jit_context():
680
+ if is_stub_tensor(output):
681
+ output = output.stub_sync()
682
+ return jit_context().run_graph(phase, output, *tuple(new_inputs))
591
683
 
592
684
  return output
593
685
 
@@ -603,10 +695,13 @@ class _MindsporeFunctionExecutor:
603
695
  compile_args = self._generate_compile_args(args)
604
696
  key_id = self._get_key_id()
605
697
  compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
606
- self.input_signature)
698
+ self.input_signature,
699
+ self._enable_auto_dynamic)
607
700
 
608
- # Restore the mutable attr for every arg.
609
- compile_args = _restore_mutable_attr(args, compile_args)
701
+ # Add mutable for compile_args for two scene:
702
+ # 1) Origin args is mutable.
703
+ # 2) Args contains sequence with gradient tensor.
704
+ compile_args = _add_mutable_attr(args, compile_args, _pynative_executor.requires_grad())
610
705
  self._compile_args = compile_args
611
706
  generate_name, echo_function_name = self._get_generate_name()
612
707
  # The full Function name
@@ -621,7 +716,7 @@ class _MindsporeFunctionExecutor:
621
716
  f'`{self.fn.__module__}`')
622
717
  self.obj.__parse_method__ = method_name
623
718
  if isinstance(self.obj, ms.nn.Cell):
624
- generate_name = generate_name + '.' + str(self.obj.create_time)
719
+ generate_name = generate_name + '.' + str(self.obj.create_time) + self.obj.phase
625
720
  create_time = str(self.obj.create_time)
626
721
  else:
627
722
  generate_name = generate_name + '.' + str(self._create_time)
@@ -645,11 +740,14 @@ class _MindsporeFunctionExecutor:
645
740
  parameter_ids = _get_parameter_ids(args, kwargs)
646
741
  if parameter_ids != "":
647
742
  key = str(key) + '.' + parameter_ids
743
+
744
+ key = str(key) + "." + _get_hook_key(*args, **kwargs)
745
+
648
746
  phase = generate_name + '.' + str(key)
649
747
 
650
748
  update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
651
749
 
652
- if phase in ms_compile_cache:
750
+ if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
653
751
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
654
752
  # generated in generate_arguments_key.
655
753
  self._graph_executor.clear_compile_arguments_resource()
@@ -671,7 +769,7 @@ class _MindsporeFunctionExecutor:
671
769
  setattr(self.fn.__func__, "__jit_function__", True)
672
770
  else:
673
771
  setattr(self.fn, "__jit_function__", True)
674
- is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
772
+ is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase)
675
773
  if isinstance(self.fn, types.MethodType):
676
774
  delattr(self.fn.__func__, "__jit_function__")
677
775
  else:
@@ -679,11 +777,14 @@ class _MindsporeFunctionExecutor:
679
777
  else:
680
778
  if isinstance(self.obj, ms.nn.Cell):
681
779
  self._graph_executor.set_weights_values(self.obj.parameters_dict())
682
- is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase, True)
780
+ is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase)
683
781
 
684
782
  if not is_compile:
685
783
  raise RuntimeError("Executor compile failed.")
784
+ set_parameter_hook_updated(False)
686
785
  ms_compile_cache.add(phase)
786
+ if hasattr(self.obj, "phase"):
787
+ self.obj.phase_cache[self.obj.phase] = phase
687
788
 
688
789
  return phase
689
790
 
@@ -704,7 +805,7 @@ class _MindsporeFunctionExecutor:
704
805
  else:
705
806
  key_id = str(id(self.obj)) + str(self._create_time)
706
807
 
707
- if _pynative_executor.grad_flag():
808
+ if _pynative_executor.requires_grad():
708
809
  key_id = key_id + ".grad"
709
810
  return key_id
710
811
 
@@ -714,9 +815,9 @@ class _MindsporeFunctionExecutor:
714
815
  self.fn.__code__.co_firstlineno)
715
816
  echo_function_name = "function \"" + self.fn.__name__ + "\" at the file \"" + self.fn.__code__.co_filename \
716
817
  + "\", line " + str(self.fn.__code__.co_firstlineno)
717
- if _pynative_executor.grad_flag():
818
+ if _pynative_executor.requires_grad():
718
819
  generate_name = generate_name + ".grad"
719
- if _is_pynative_parallel():
820
+ if self.fn.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
720
821
  generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
721
822
  return generate_name, echo_function_name
722
823
 
@@ -777,6 +878,14 @@ class _MindsporeFunctionExecutor:
777
878
  """
778
879
  return _get_args_for_run(self, args_list, kwargs, self._compile_args)
779
880
 
881
+ def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
882
+ """Get graph proto from pipeline."""
883
+ if use_prefix:
884
+ exec_id = exec_id + '.' + obj.arguments_key
885
+ if self._graph_executor.has_compiled(exec_id) is False:
886
+ return None
887
+ return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
888
+
780
889
 
781
890
  # The attributes used to identify a given object.
782
891
  attr_op = {"__str__": lambda x: x.__str__(),
@@ -789,6 +898,13 @@ attr_op = {"__str__": lambda x: x.__str__(),
789
898
  }
790
899
 
791
900
 
901
+ def _is_inner_func(func):
902
+ """Check whether the func is an inner func which needs hash_args parameter."""
903
+ # This is a workaround for inner api, should fix it later.
904
+ inner_func = ["after_shard", "_wrap_container"]
905
+ return func.__name__ in inner_func
906
+
907
+
792
908
  def _get_obj_id(input_obj):
793
909
  """Get hash id of single object."""
794
910
  obj_id = ".".join(
@@ -803,50 +919,227 @@ def _get_jit_hash(hash_input):
803
919
  return _get_obj_id(hash_input)
804
920
 
805
921
 
806
- def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=None, compile_once=False):
922
+ def _get_hash_obj(options):
923
+ hash_obj = None
924
+ if "hash_args" in options:
925
+ hash_obj = _get_jit_hash(options["hash_args"])
926
+ del options["hash_args"]
927
+ return hash_obj
928
+
929
+
930
+ def _check_option_device(option, device):
931
+ """Check jit options wiwh device"""
932
+ option_device_cfgs = {
933
+ 'disable_format_transform': ['GPU'],
934
+ 'exec_order': ['Ascend'],
935
+ 'ge_options': ['Ascend'],
936
+ 'infer_boost': ['Ascend'],
937
+ }
938
+ if option in option_device_cfgs and device not in option_device_cfgs[option]:
939
+ logger.warning(f"For 'jit(options)', the option '{option}' is only support device in "
940
+ f"'{option_device_cfgs[option]}', but got '{device}', ignore it.")
941
+
942
+
943
+ def _check_option_backend(option, backend):
944
+ """Check jit options wiwh backend"""
945
+ option_backend_cfgs = {
946
+ 'disable_format_transform': ['ms_backend'],
947
+ 'exec_order': ['ms_backend'],
948
+ 'ge_options': ['GE'],
949
+ 'infer_boost': ['ms_backend'],
950
+ }
951
+ if option in option_backend_cfgs and backend not in option_backend_cfgs[option]:
952
+ logger.warning(f"For 'jit(options)', the option '{option}' is only support backend in "
953
+ f"'{option_backend_cfgs[option]}', but got '{backend}', ignore it.")
954
+
955
+
956
+ def _check_disable_format_transform_value(option, disable_format_transform):
957
+ """check disable_format_transform option value"""
958
+ if not isinstance(disable_format_transform, bool):
959
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be bool, "
960
+ f"but got {type(disable_format_transform)}.")
961
+
962
+
963
+ def _check_exec_order_value(option, exec_order):
964
+ """check exec_order option value"""
965
+ if not isinstance(exec_order, str):
966
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(exec_order)}.")
967
+
968
+ if exec_order not in ['bfs', 'dfs']:
969
+ raise ValueError(f"For '{option}', the value of '{option}' must be one of "
970
+ f"['bfs', 'dfs'], but got '{exec_order}'.")
971
+
972
+
973
+ def _check_ge_options_value(option, ge_options):
974
+ """check ge_options option value"""
975
+ if not isinstance(ge_options, dict):
976
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be dict, but got {type(ge_options)}.")
977
+
978
+ for level, options in ge_options.items():
979
+ if level not in ['global', 'session']:
980
+ raise ValueError(f"For '{option}', the key of '{option}' must be one of "
981
+ f"['global', 'session'], but got '{level}'.")
982
+
983
+ if not isinstance(options, dict):
984
+ raise TypeError(f"For '{option}', the type of {level} options must be dict, "
985
+ f"but got {type(options)}. The error options: {options}.")
986
+
987
+ for key, value in options.items():
988
+ if not isinstance(key, str):
989
+ raise TypeError(f"For '{option}', the type of key and value must be str, "
990
+ f"but got {type(key)}. The error key is {key}.")
991
+ if not isinstance(value, str):
992
+ raise TypeError(f"For '{option}', the type of key and value must be str, "
993
+ f"but got {type(value)}. The error value is {value}")
994
+
995
+
996
+ def _check_infer_boost_value(option, value):
997
+ """check infer_boost option value"""
998
+ if not isinstance(value, str):
999
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(value)}.")
1000
+
1001
+ if value not in ['on', 'off']:
1002
+ raise ValueError(f"For '{option}', the value of '{option}' must be one of ['on', 'off'], but got '{value}'.")
1003
+
1004
+
1005
+ def _check_option_value(option, value):
1006
+ """check jit options wiwh value"""
1007
+ option_valuecheck_funcs = {
1008
+ 'disable_format_transform': _check_disable_format_transform_value,
1009
+ 'exec_order': _check_exec_order_value,
1010
+ 'ge_options': _check_ge_options_value,
1011
+ 'infer_boost': _check_infer_boost_value,
1012
+ }
1013
+ if option in option_valuecheck_funcs:
1014
+ option_valuecheck_funcs[option](option, value)
1015
+ else:
1016
+ logger.warning(f"For 'jit(options)', the option argument '{option}' is not recognized, please check!"
1017
+ f"For detailed usage of 'jit(options)', please refer to the Mindspore official website.")
1018
+
1019
+
1020
+ def _check_options(options, backend):
1021
+ """Check jit options"""
1022
+ # check whether there are deprecated parameters in the dict `options`.
1023
+ deprecated_args = {'mode': 'capture_mode', 'input_signature': 'dynamic', 'hash_args': '',
1024
+ 'jit_config': 'jit_level, fullgraph or options', 'compile_once': ''}
1025
+ for key, value in deprecated_args.items():
1026
+ if key in options:
1027
+ log = f"For 'jit', the parameter '{key}' has been deprecated."
1028
+ if value != '':
1029
+ log += f" Please use the parameter '{value}' instead. For more details, please refer to " \
1030
+ f"https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html."
1031
+ logger.warning(log)
1032
+ del options[key]
1033
+
1034
+ # check each option's backend compatibility and validate its value
1035
+ for option, value in options.items():
1036
+ _check_option_backend(option, backend)
1037
+ _check_option_value(option, value)
1038
+
1039
+
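A hedged walk-through of `_check_options` as defined above: deprecated keys are warned about and removed, recognized options have their values validated, and unrecognized options only trigger a warning:

>>> opts = {"mode": "PSJit", "infer_boost": "on", "not_an_option": 1}
>>> _check_options(opts, backend="ms_backend")
>>> # 'mode' is deprecated and removed; 'infer_boost' is checked against ['on', 'off'];
>>> # 'not_an_option' is only warned about and remains in the dict.
>>> sorted(opts)
['infer_boost', 'not_an_option']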
1040
+ def jit(
1041
+ function: Optional[Callable] = None,
1042
+ *,
1043
+ capture_mode: str = "ast",
1044
+ jit_level: str = "O0",
1045
+ dynamic: int = 0,
1046
+ fullgraph: bool = False,
1047
+ backend: str = "",
1048
+ **options):
807
1049
  """
808
1050
  Create a callable MindSpore graph from a Python function.
809
1051
 
810
1052
  This allows the MindSpore runtime to apply optimizations based on graph.
811
1053
 
812
1054
  Note:
813
- - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
814
- will not accept `**kwargs`.
815
- - It is not supported to run a function with decoration @jit(mode=“PIJit”)
816
- in static graph mode, in which case the decoration @jit(mode=“PIJit”) is considered invalid.
817
- - Calls to functions with decorated @jit(mode=“PIJit”) inside functions
818
- decorated with @jit(mode=“PIJit”) are not supported,
819
- and the decoration @jit(mode=“PIJit”) is considered invalid.
1055
+ - Running a function decorated with @jit(capture_mode="bytecode") in static graph mode
1056
+ is not supported; in that case, the @jit(capture_mode="bytecode") decoration is considered invalid.
1057
+ - Calling a function decorated with @jit(capture_mode="bytecode") from inside a function
1058
+ decorated with @jit(capture_mode="ast") is not supported,
1059
+ and the @jit(capture_mode="bytecode") decoration is considered invalid.
820
1060
 
821
1061
  Args:
822
- fn (Function): The Python function that will be run as a graph. Default: ``None`` .
823
- mode (str): The type of jit used, the value of mode should be ``PIJit`` or ``PSJit``. Default: ``PSJit`` .
824
-
825
- - `PSJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html>`_ :
826
- Parse python ast to build graph.
827
- - `PIJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/pynative.html#pijit>`_ :
828
- Parse python bytecode to build graph at runtime.
829
-
830
- input_signature (Union[Tuple, List, Dict, Tensor]): The Tensor which describes the input arguments. The
831
- shape and dtype of the Tensor will be supplied to this function. If `input_signature` is specified, the
832
- input parameters of `fn` cannot accept `**kwargs`, and the shape and dtype of actual inputs should keep the
833
- same as `input_signature`. Otherwise, TypeError will be raised. There are two mode for `input_signature`:
834
-
835
- - Full mode: Arguments is a Tuple, List or a Tensor, and they will be used as all compile inputs
836
- for graph-compiling.
837
- - Incremental mode: Argument is a Dict, and they will set to some of the graph inputs, which will be
838
- substituted into the input at the corresponding position for graph-compiling.
839
-
840
- Default: ``None`` .
841
-
842
- hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
843
- like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
844
- will trigger recompilation. Default: ``None`` .
845
- jit_config (JitConfig): Jit config for compile. Default: ``None`` .
846
- compile_once(bool): ``True``: The function would be compiled once when it was created many times.
847
- But it may be wrong if the free variables were changed. ``False`` : It would be recompiled when
848
- it was created again.
849
- Default: ``False`` .
1062
+ function (Function, optional): The Python function that will be run as a graph. Default: ``None``.
1063
+
1064
+ Keyword Args:
1065
+ capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
1066
+ should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
1067
+
1068
+ - `ast <https://www.mindspore.cn/docs/en/r2.5.0/model_train/program_form/static_graph.html>`_ :
1069
+ Parse Python ast to build graph.
1070
+ - `bytecode <https://www.mindspore.cn/docs/en/r2.5.0/model_train/program_form/pynative.html#pijit>`_ :
1071
+ Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
1072
+ change and/or deletion.
1073
+ - `trace` : Trace the execution of Python code to build graph. This is an experimental prototype that is
1074
+ subject to change and/or deletion.
1075
+
1076
+ jit_level (str, optional): Used to control the compilation optimization level. Currently, it is only effective
1077
+ with the default backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
1078
+
1079
+ - `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
1080
+ - `O1`: Enable common optimizations and automatic operator fusion. This optimization
1081
+ level is experimental and is being improved.
1082
+
1083
+ dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
1084
+ is as follows:
1085
+
1086
+ - `0`: Do not perform dynamic shape compilation.
1087
+ - `1`: Enable dynamic shape compilation and automatically detect shape changes.
1088
+
1089
+ fullgraph (bool, optional): Whether to capture the entire function into a graph. If False, jit attempts to
1090
+ be compatible with all Python syntax in the function as much as possible. If True, we require that the
1091
+ entire function can be captured into a graph. If this is not possible (that is, if there is unsupported
1092
+ Python syntax), an exception is raised. This currently only applies when capture_mode is ``ast``.
1093
+ Default: ``False``.
1094
+ backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
1095
+ use the ``GE`` backend for Atlas training series products and the ``ms_backend`` backend for others, including Atlas
1096
+ A2 training series products by default.
1097
+
1098
+ - `ms_backend`: Adopt KernelByKernel execution mode.
1099
+ - `GE`: Adopt Sink execution mode. The whole model is sunk to the device for execution; this is only applicable to
1100
+ the top cell of the model, and can only be used on the Ascend platform.
1101
+
1102
+ **options (dict): A dictionary of options to pass to the compilation backend.
1103
+
1104
+ Some options are device-specific; see the table below for details:
1105
+
1106
+ +---------------------------+---------------------------+-------------------------+
1107
+ | Option Parameters | Hardware Platform Support | Backend Support |
1108
+ +===========================+===========================+=========================+
1109
+ | disable_format_transform | GPU | ms_backend |
1110
+ +---------------------------+---------------------------+-------------------------+
1111
+ | exec_order | Ascend | ms_backend |
1112
+ +---------------------------+---------------------------+-------------------------+
1113
+ | ge_options | Ascend | GE |
1114
+ +---------------------------+---------------------------+-------------------------+
1115
+ | infer_boost | Ascend | ms_backend |
1116
+ +---------------------------+---------------------------+-------------------------+
1117
+
1118
+ - disable_format_transform (bool, optional): Whether to disable the automatic format transform function
1119
+ from NCHW to NHWC. When the network training performance of fp16 is worse than fp32,
1120
+ `disable_format_transform` can be set to ``True`` to try to improve training performance.
1121
+ Default: ``False`` .
1122
+ - exec_order (str, optional): Set the sorting method for operator execution; currently only two sorting
1123
+ methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
1124
+
1125
+ - `bfs`: The default sorting method, breadth priority, good communication masking, relatively good
1126
+ performance.
1127
+ - `dfs`: An optional sorting method, depth-first sorting. The performance is relatively worse than that
1128
+ of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
1129
+ other execution orders run out of memory (OOM).
1130
+
1131
+ - ge_options (dict): Set options for ge backend. The options are divided into two categories: global,
1132
+ and session. This is an experimental prototype that is subject to change and/or deletion.
1133
+ For detailed information, please refer to `Ascend community <https://www.hiascend.com/document/detail/zh/canncommercial/80RC3/apiref/ascendgraphapi/atlasgeapi_07_0146.html>`_ .
1134
+
1135
+ - global (dict): Set global options.
1136
+ - session (dict): Set session options.
1137
+
1138
+ - infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
1139
+ the inference mode is disabled. The valid values are as follows:
1140
+
1141
+ - `on`: Enable inference mode to get better inference performance.
1142
+ - `off`: Disable inference mode and use the normal forward computation for inference. Performance is relatively poor.
850
1143
 
851
1144
  Returns:
852
1145
  Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
@@ -865,12 +1158,12 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
865
1158
  >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
866
1159
  >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
867
1160
  ...
868
- >>> # create a callable MindSpore graph by calling decorator @jit
1161
+ >>> # create a callable MindSpore graph by calling jit
869
1162
  >>> def tensor_add(x, y):
870
1163
  ... z = x + y
871
1164
  ... return z
872
1165
  ...
873
- >>> tensor_add_graph = jit(fn=tensor_add)
1166
+ >>> tensor_add_graph = jit(function=tensor_add)
874
1167
  >>> out = tensor_add_graph(x, y)
875
1168
  ...
876
1169
  >>> # create a callable MindSpore graph through decorator @jit
@@ -881,180 +1174,70 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
881
1174
  ...
882
1175
  >>> out = tensor_add_with_dec(x, y)
883
1176
  ...
884
- >>> # create a callable MindSpore graph through decorator @jit with input_signature parameter
885
- >>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
886
- ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
887
- ... def tensor_add_with_sig(x, y):
1177
+ >>> # create a callable MindSpore graph and capture the entire function into the graph
1178
+ >>> @jit(fullgraph=True)
1179
+ ... def tensor_add_fullgraph(x, y):
888
1180
  ... z = x + y
889
1181
  ... return z
890
1182
  ...
891
- >>> out = tensor_add_with_sig(x, y)
892
- ...
893
- >>> @jit(input_signature={"y": Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))})
894
- ... def tensor_add_with_sig_1(x, y):
895
- ... z = x + y
896
- ... return z
897
- ...
898
- >>> out1 = tensor_add_with_sig_1(x, y)
899
- ...
900
- ... # Set hash_args as fn, otherwise cache of compiled closure_fn will not be reused.
901
- ... # While fn differs during calling again, recompilation will be triggered.
902
- >>> def func(x):
903
- ... return ops.exp(x)
904
- ...
905
- >>> def closure_fn(x, fn):
906
- ... @jit(hash_args=fn)
907
- ... def inner_fn(a):
908
- ... return fn(a)
909
- ... return inner_fn(x)
910
- ...
911
- >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
912
- >>> for i in range(10):
913
- ... closure_fn(inputs, func)
914
- ...
915
- ... # Set compile_once = True, otherwise the train_step will be compiled again.
916
- >>> def train(x):
917
- ... @jit(compile_once = True)
918
- ... def train_step(x):
919
- ... return ops.exp(x)
920
- ... for i in range(10):
921
- ... train_step(x)
922
- ...
923
- >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
924
- >>> for i in range(10):
925
- ... train(inputs)
1183
+ >>> out = tensor_add_fullgraph(x, y)
926
1184
  """
927
1185
 
928
- def wrap_mindspore(func):
929
- if not isinstance(compile_once, bool):
930
- logger.warning(f"The parameter `compile_once` of jit should be a bool, "
931
- f"but got {type(compile_once)}.")
932
- if hash_args:
933
- hash_obj = _get_jit_hash(hash_args)
934
- elif compile_once:
935
- hash_obj = 0
936
- else:
1186
+ capture_mode = Validator.check_string(capture_mode, ["ast", "bytecode", "trace"], "capture_mode", "jit")
1187
+ jit_level = Validator.check_string(jit_level, ["O0", "O1"], "jit_level", "jit")
1188
+ dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
1189
+ fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
1190
+ if backend == "":
1191
+ backend = "GE" if MSContext.get_instance().get_ascend_soc_version() == "ascend910" else "ms_backend"
1192
+ backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
1193
+ jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
1194
+ hash_obj = _get_hash_obj(options)
1195
+ _check_options(options, backend)
1196
+ options_str = json.dumps(options)
1197
+ infer_boost = options['infer_boost'] if 'infer_boost' in options else "off"
1198
+ exc_mode = options['exc_mode'] if 'exc_mode' in options else "auto"
1199
+ jit_config = JitConfig(jit_level=jit_level, exc_mode=exc_mode, jit_syntax_level=jit_syntax_level,
1200
+ infer_boost=infer_boost, backend=backend, options=options_str)
1201
+
1202
+ def wrap_func(func):
1203
+ nonlocal hash_obj
1204
+ if hash_obj is None or not _is_inner_func(func):
937
1205
  hash_obj = int(time.time() * 1e9)
938
1206
 
939
- dyn_args = _process_dyn_args(func, input_signature)
940
-
941
1207
  @wraps(func)
942
1208
  def staging_specialize(*args, **kwargs):
943
1209
  if os.getenv("MS_JIT") == '0':
944
1210
  return func(*args, **kwargs)
945
1211
 
946
1212
  args, kwargs = _handle_func_args(func, *args, **kwargs)
947
-
948
1213
  process_obj = None
949
1214
  if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
950
1215
  process_obj = args[0]
951
- # only the function or cell instance wrapped by shard will fall into this branch
952
- if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
953
- process_obj = hash_args
954
1216
  # Handle auto mixed precision strategy.
955
1217
  if not hasattr(func, "amp_strategy"):
956
1218
  if isinstance(func, types.MethodType):
957
1219
  setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
958
1220
  else:
959
1221
  setattr(func, "amp_strategy", get_curr_amp_strategy())
960
- out = _MindsporeFunctionExecutor(func, hash_obj, dyn_args, process_obj, jit_config)(*args, **kwargs)
1222
+
1223
+ ms_function_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic)
1224
+ out = ms_function_executor(*args, **kwargs)
961
1225
  return out
962
1226
 
963
1227
  return staging_specialize
964
1228
 
965
- wrap_func = wrap_mindspore
966
- if mode == "PIJit":
967
- wrap_func = PIJitCaptureContext(jit_config, input_signature)
1229
+ if capture_mode == "bytecode":
1230
+ wrap_func = PIJitCaptureContext(jit_config)
1231
+ elif capture_mode == "trace":
1232
+ if function is not None:
1233
+ return _jit_trace(function)
1234
+ return _jit_trace
968
1235
 
969
- if fn is not None:
970
- return wrap_func(fn)
1236
+ if function is not None:
1237
+ return wrap_func(function)
971
1238
  return wrap_func
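A usage sketch of the keyword-only interface above (the option values are illustrative choices, not recommendations):

>>> import numpy as np
>>> from mindspore import Tensor, jit
...
>>> @jit(capture_mode="ast", jit_level="O1", dynamic=0, fullgraph=False,
...      backend="ms_backend", infer_boost="off")
... def mul_add(x, y):
...     return x * y + y
...
>>> out = mul_add(Tensor(np.ones((2, 2), np.float32)),
...               Tensor(np.ones((2, 2), np.float32)))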
972
1239
 
973
1240
 
974
- def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
975
- """
976
- Create a callable MindSpore graph from a Python function.
977
-
978
- This allows the MindSpore runtime to apply optimizations based on graph.
979
-
980
- Note:
981
- - `ms_function` will be deprecated and removed in a future version. Please use :func:`mindspore.jit` instead.
982
- - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
983
- will not accept `**kwargs`.
984
-
985
- Args:
986
- fn (Function): The Python function that will be run as a graph. Default: ``None`` .
987
- input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
988
- will be supplied to this function. The shape and dtype of actual inputs of `fn` should
989
- keep the same as input_signature. Otherwise, TypeError will be raised. Default: ``None`` .
990
- hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
991
- like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
992
- will trigger recompilation. Default: ``None`` .
993
- jit_config (JitConfig): Jit config for compile. Default: ``None`` .
994
-
995
- Returns:
996
- Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
997
- None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
998
- equal to the case when `fn` is not None.
999
-
1000
- Supported Platforms:
1001
- ``Ascend`` ``GPU`` ``CPU``
1002
-
1003
- Examples:
1004
- >>> import numpy as np
1005
- >>> from mindspore import Tensor
1006
- >>> from mindspore import ops
1007
- >>> from mindspore import ms_function
1008
- ...
1009
- >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
1010
- >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
1011
- ...
1012
- >>> # create a callable MindSpore graph by calling ms_function
1013
- >>> def tensor_add(x, y):
1014
- ... z = x + y
1015
- ... return z
1016
- ...
1017
- >>> tensor_add_graph = ms_function(fn=tensor_add)
1018
- >>> out = tensor_add_graph(x, y)
1019
- ...
1020
- >>> # create a callable MindSpore graph through decorator @ms_function
1021
- >>> @ms_function
1022
- ... def tensor_add_with_dec(x, y):
1023
- ... z = x + y
1024
- ... return z
1025
- ...
1026
- >>> out = tensor_add_with_dec(x, y)
1027
- ...
1028
- >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
1029
- >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
1030
- ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
1031
- ... def tensor_add_with_sig(x, y):
1032
- ... z = x + y
1033
- ... return z
1034
- ...
1035
- >>> out = tensor_add_with_sig(x, y)
1036
- ...
1037
- ... # Set hash_args as fn, otherwise cache of compiled `closure_fn` will not be reused.
1038
- ... # While fn differs during calling again, recompilation will be triggered.
1039
- >>> def func(x):
1040
- ... return ops.exp(x)
1041
- ...
1042
- >>> def closure_fn(x, fn):
1043
- ... @ms_function(hash_args=fn)
1044
- ... def inner_fn(a):
1045
- ... return fn(a)
1046
- ... return inner_fn(x)
1047
- ...
1048
- >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
1049
- >>> for i in range(10):
1050
- ... closure_fn(inputs, func)
1051
- """
1052
-
1053
- logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
1054
- "Please use 'mindspore.jit' instead.")
1055
- return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)
1056
-
1057
-
1058
1241
  def _core(fn=None, **flags):
1059
1242
  """
1060
1243
  A decorator that adds a flag to the function.
@@ -1147,69 +1330,6 @@ def _no_recursive(callable_obj):
1147
1330
  return callable_obj
1148
1331
 
1149
1332
 
1150
- def ms_class(cls):
1151
- """
1152
- Class decorator for user-defined classes.
1153
-
1154
- This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
1155
-
1156
- Note:
1157
- `ms_class` will be deprecated and removed in a future version. Please use :func:`mindspore.jit_class` instead.
1158
-
1159
- Args:
1160
- cls (Class): User-defined class.
1161
-
1162
- Returns:
1163
- Class.
1164
-
1165
- Raises:
1166
- TypeError: If ms_class is used for non-class types or nn.Cell.
1167
- AttributeError: If the private attributes or magic methods of the class decorated with ms_class is called.
1168
-
1169
- Supported Platforms:
1170
- ``Ascend`` ``GPU`` ``CPU``
1171
-
1172
- Examples:
1173
- >>> import mindspore.nn as nn
1174
- >>> from mindspore import ms_class
1175
- ...
1176
- >>> @ms_class
1177
- ... class UserDefinedNet:
1178
- ... def __init__(self):
1179
- ... self.value = 10
1180
- ...
1181
- ... def func(self, x):
1182
- ... return 2 * x
1183
- ...
1184
- >>> class Net(nn.Cell):
1185
- ... def __init__(self):
1186
- ... super(Net, self).__init__()
1187
- ... self.net = UserDefinedNet()
1188
- ...
1189
- ... def construct(self, x):
1190
- ... out = self.net.value + self.net.func(x)
1191
- ... return out
1192
- ...
1193
- >>> net = Net()
1194
- >>> out = net(5)
1195
- >>> print(out)
1196
- 20
1197
- """
1198
-
1199
- logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
1200
- "Please use 'mindspore.jit_class' instead.")
1201
-
1202
- # Check if cls is of type class.
1203
- if not inspect.isclass(cls):
1204
- raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
1205
- # Check if cls is nn.Cell.
1206
- if issubclass(cls, ms.nn.Cell):
1207
- raise TypeError(f"Decorator ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
1208
- logger.info(f'Found ms_class: {cls}.')
1209
- setattr(cls, '__ms_class__', True)
1210
- return cls
1211
-
1212
-
1213
1333
  def jit_class(cls):
1214
1334
  """
1215
1335
  Class decorator for user-defined classes.
@@ -1266,28 +1386,6 @@ def jit_class(cls):
1266
1386
  return cls
1267
1387
 
1268
1388
 
1269
- def set_adapter_config(config):
1270
- """
1271
- Register configuration information for MSAdapter.
1272
-
1273
- Args:
1274
- config (dict): Configuration information.
1275
- """
1276
- if not isinstance(config, dict):
1277
- raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
1278
- for key, value in config.items():
1279
- if key == "Tensor":
1280
- ms_adapter_registry.register_tensor(value)
1281
- elif key == "Parameter":
1282
- ms_adapter_registry.register_parameter(value)
1283
- elif key == "convert_object_map":
1284
- ms_adapter_registry.register_convert_map(value)
1285
- elif key == "convert_adapter_tensor_map":
1286
- ms_adapter_registry.register_convert_adapter_tensor_map(value)
1287
- else:
1288
- raise ValueError(f"Unsupported key in adapter config: {key}")
1289
-
1290
-
1291
1389
  def _function_forbid_reuse(func):
1292
1390
  if not inspect.isfunction(func):
1293
1391
  raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
@@ -1351,8 +1449,6 @@ class _no_grad(contextlib.ContextDecorator):
1351
1449
  self.prev_state = False
1352
1450
 
1353
1451
  def __enter__(self):
1354
- if context.get_context("mode") == context.GRAPH_MODE:
1355
- raise RuntimeError("For no_grad feature, currently only support Pynative mode, but got Graph mode.")
1356
1452
  self.prev_state = _pynative_executor.enable_grad()
1357
1453
  _pynative_executor.set_enable_grad(False)
1358
1454
 
@@ -1481,7 +1577,24 @@ class _PyNativeExecutor:
1481
1577
  Return:
1482
1578
  None.
1483
1579
  """
1484
- return self._executor.grad(grad, obj, weights, grad_position, *args)
1580
+ return self._executor.grad(grad, obj, weights, grad_position, False, *args)
1581
+
1582
+ def grad_aux(self, obj, grad, weights, grad_position, *args):
1583
+ """
1584
+ Run the grad graph with auxiliary outputs.
1585
+
1586
+ Args:
1587
+ obj (Function/Cell): The function or cell instance.
1588
+ grad (GradOperation): The GradOperation object.
1589
+ weights (ParameterTuple): The weights of cell instance.
1590
+ grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
1591
+ If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
1592
+ args (tuple): Function or cell input arguments.
1593
+
1594
+ Return:
1595
+ None.
1596
+ """
1597
+ return self._executor.grad(grad, obj, weights, grad_position, True, *args)
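`grad_aux` is the executor-level entry for gradients with auxiliary outputs; at the Python API level this is typically reached through `has_aux=True` in `mindspore.grad` or `mindspore.value_and_grad` (the exact routing is an assumption here, but the semantics match `has_aux`):

>>> import mindspore as ms
>>> from mindspore import Tensor
...
>>> def fn(x):
...     loss = (x * x).sum()
...     aux = x + 1          # auxiliary output, excluded from differentiation
...     return loss, aux
...
>>> grad_fn = ms.grad(fn, grad_position=0, has_aux=True)
>>> grads, aux = grad_fn(Tensor([1.0, 2.0], ms.float32))  # expected: gradient of the first output plus the aux value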
1485
1598
 
1486
1599
  def clear_res(self):
1487
1600
  """
@@ -1501,18 +1614,18 @@ class _PyNativeExecutor:
1501
1614
  """
1502
1615
  self._executor.sync()
1503
1616
 
1504
- def grad_jit(self, output, *args):
1617
+ def grad_jit(self, *args):
1505
1618
  """
1506
1619
  Building grad graph decorated by jit.
1507
1620
 
1508
1621
  Args:
1509
- output (tuple): The function or cell decorated by jit output object.
1510
1622
  args (tuple): Function or cell decorated by jit input arguments.
1511
1623
 
1512
1624
  Return:
1513
- None.
1625
+ output, the output object of the function or cell decorated by jit.
1514
1626
  """
1515
- return self._executor.grad_jit(output, *args)
1627
+ output = self._executor.grad_jit(*args)
1628
+ return output
1516
1629
 
1517
1630
  def call_custom_bprop(self, obj, output, *args, **kwargs):
1518
1631
  """
@@ -1617,6 +1730,15 @@ class _PyNativeExecutor:
1617
1730
  """
1618
1731
  self._executor.set_is_run_recompute(status)
1619
1732
 
1733
+ def high_order(self):
1734
+ """
1735
+ Return whether the current scene is a high-order gradient scene. This is an inner interface.
1736
+
1737
+ Return:
1738
+ Bool.
1739
+ """
1740
+ return self._executor.high_order()
1741
+
1620
1742
  def set_cell_use_dynamic_shape_process(self, flag):
1621
1743
  """
1622
1744
  Set the dynamic shape flag of eval process.
@@ -1699,7 +1821,6 @@ class _CellGraphExecutor:
1699
1821
  # create needed graph by lazy mode
1700
1822
  self.is_init = False
1701
1823
  self.enable_tuple_broaden = False
1702
- self.obfuscate_config = None # used for model's dynamic obfuscation
1703
1824
  self._graph_executor = GraphExecutor_.get_instance()
1704
1825
  self._graph_executor.set_py_exe_path(sys.executable)
1705
1826
  self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
@@ -1791,6 +1912,7 @@ class _CellGraphExecutor:
1791
1912
  Str, the full phase of the cell.
1792
1913
  Bool, if the graph has been compiled before, return False, else return True.
1793
1914
  """
1915
+ _init_auto_parallel_context(obj)
1794
1916
  obj.__parse_method__ = 'construct'
1795
1917
  if not hasattr(obj, obj.__parse_method__):
1796
1918
  raise AttributeError(
@@ -1803,8 +1925,12 @@ class _CellGraphExecutor:
1803
1925
  self.enable_tuple_broaden = obj.enable_tuple_broaden
1804
1926
  logger.debug(f"Convert the network: {do_convert}.")
1805
1927
  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
1928
+
1806
1929
  key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
1807
1930
  obj.arguments_key = str(key)
1931
+
1932
+ obj.arguments_key = obj.arguments_key + "." + _get_hook_key(*args, **kwargs)
1933
+
1808
1934
  # When exist parameter in the top graph inputs, need check if the parameter object has changed.
1809
1935
  parameter_ids = _get_parameter_ids(args, kwargs)
1810
1936
  if parameter_ids != "":
@@ -1814,11 +1940,12 @@ class _CellGraphExecutor:
1814
1940
  obj.phase_cache[raw_phase] = phase
1815
1941
  update_auto_dynamic_shape_phase(args, key_id, phase)
1816
1942
  obj.current_phase = phase
1817
- if phase in obj.compile_cache and self.has_compiled(phase):
1943
+ if phase in obj.compile_cache and self.has_compiled(phase) and not parameter_hook_updated():
1818
1944
  logger.debug("%r graph has existed.", phase)
1819
1945
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
1820
1946
  # generated in generate_arguments_key.
1821
1947
  self._graph_executor.clear_compile_arguments_resource()
1948
+ _clear_auto_parallel_context(obj)
1822
1949
  return phase, False
1823
1950
 
1824
1951
  full_function_name = obj.__class__.__name__ + '.' + str(obj.instance_count) + '.' + str(id(type(obj)))
@@ -1836,10 +1963,12 @@ class _CellGraphExecutor:
1836
1963
  else:
1837
1964
  jit_config_dict = JitConfig().jit_config_dict
1838
1965
  self._graph_executor.set_jit_config(jit_config_dict)
1839
- result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
1966
+ gc.collect()
1967
+ result = self._graph_executor.compile(obj, args, kwargs, phase)
1840
1968
  obj.compile_cache.add(phase)
1841
1969
  if not result:
1842
1970
  raise RuntimeError("Executor compile failed.")
1971
+ set_parameter_hook_updated(False)
1843
1972
  graph = self._graph_executor.get_func_graph(phase)
1844
1973
 
1845
1974
  if graph is None:
@@ -1856,6 +1985,7 @@ class _CellGraphExecutor:
1856
1985
  self._build_data_graph(obj, phase)
1857
1986
  elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
1858
1987
  _parameter_broadcast(obj)
1988
+ _clear_auto_parallel_context(obj)
1859
1989
  return phase, True
1860
1990
 
1861
1991
  def _update_param_node_default_input(self, phase, replace):
@@ -1875,7 +2005,7 @@ class _CellGraphExecutor:
1875
2005
  return self._graph_executor.get_allreduce_fusion(real_phase)
1876
2006
 
1877
2007
  def __call__(self, obj, *args, phase='predict'):
1878
- if context.get_context("precompile_only") or _is_role_sched():
2008
+ if context.get_context("precompile_only") or os.getenv('MS_DEV_PRECOMPILE_ONLY') == '1' or _is_role_sched():
1879
2009
  return None
1880
2010
  return self.run(obj, *args, phase=phase)
1881
2011
 
@@ -1935,25 +2065,12 @@ class _CellGraphExecutor:
1935
2065
  """Clear the memory resource of a network."""
1936
2066
  self._graph_executor.del_net_res(obj, net_id)
1937
2067
 
1938
- def _get_branch_control_input(self):
1939
- if ('obf_ratio' not in self.obfuscate_config.keys()) or (
1940
- 'obf_random_seed' not in self.obfuscate_config.keys()):
1941
- raise ValueError("'obf_ratio' and 'obf_random_seed' must be in obfuscate_config.")
1942
- obf_random_seed = self.obfuscate_config.get('obf_random_seed')
1943
- if obf_random_seed == 0:
1944
- branch_control_input = 0
1945
- else:
1946
- branch_control_input = _generate_branch_control_input(obf_random_seed)
1947
- return branch_control_input
1948
-
1949
2068
  def _get_func_graph(self, obj, exec_id, use_prefix=False):
1950
2069
  """Get func graph from pipeline."""
1951
2070
  if use_prefix:
1952
2071
  exec_id = exec_id + '.' + obj.arguments_key
1953
2072
  if self._graph_executor.has_compiled(exec_id) is False:
1954
2073
  return None
1955
- if self.obfuscate_config is not None:
1956
- raise ValueError('For get func graph, obfuscate_config is currently not supported now.')
1957
2074
  return self._graph_executor.get_func_graph(exec_id)
1958
2075
 
1959
2076
  def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
@@ -1962,11 +2079,6 @@ class _CellGraphExecutor:
1962
2079
  exec_id = exec_id + '.' + obj.arguments_key
1963
2080
  if self._graph_executor.has_compiled(exec_id) is False:
1964
2081
  return None
1965
- if self.obfuscate_config is not None:
1966
- branch_control_input = self._get_branch_control_input()
1967
- return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, incremental,
1968
- self.obfuscate_config['obf_ratio'],
1969
- branch_control_input)
1970
2082
  return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
1971
2083
 
1972
2084
  def get_optimize_graph_proto(self, obj):
@@ -2004,6 +2116,8 @@ def ms_memory_recycle():
2004
2116
  """
2005
2117
  if ms_compile_cache:
2006
2118
  _cell_graph_executor.del_net_res(None, ms_compile_cache)
2119
+ if os.getenv('MS_DEV_JIT_PIPELINE') != '0':
2120
+ JitExecutor_.get_instance().del_net_res(None, ms_compile_cache)
2007
2121
  ms_compile_cache.clear()
2008
2122
  for cell_cache in cells_compile_cache.values():
2009
2123
  if cell_cache:
@@ -2012,28 +2126,22 @@ def ms_memory_recycle():
2012
2126
  _ms_memory_recycle()
2013
2127
 
2014
2128
 
2015
- def _generate_branch_control_input(obf_random_seed):
2016
- """Generate append network input for dynamic obfuscation in random seed mode."""
2017
- seed_max = 2 ** 32 - 1
2018
- int_max = 2 ** 31 - 1
2019
- np.random.seed(obf_random_seed % seed_max)
2020
- # generate a string as hash function inputs
2021
- word_repo = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghigklmnopqrstuvwxyz" + "0123456789"
2022
- repo_len = len(word_repo)
2023
- sha_string = ''
2024
- string_len = 1024 * 1024
2025
- for _ in range(string_len):
2026
- rand_index = np.random.randint(0, repo_len)
2027
- sha_string += word_repo[rand_index]
2028
- # get hash result
2029
- sha_result = hashlib.sha256(sha_string.encode('utf-8')).hexdigest() # len is 64
2030
- branch_control_input = 1
2031
- hex_base = 16
2032
- for item in sha_result:
2033
- if int(item, hex_base) > 0:
2034
- branch_control_input *= int(item, hex_base)
2035
- branch_control_input %= int_max
2036
- return branch_control_input
2129
+ def set_recursion_limit(recursion_limit=1000):
2130
+ """
2131
+ Specify the recursion depth limit of function calls before graph compilation.
2132
+ It needs to be called when nested function calls are too deep or the number of subgraphs is too large.
2133
+ If recursion_limit is set larger than before, the system max stack depth should be set larger too,
2134
+ otherwise a `core dumped` exception may be raised because of system stack overflow.
2135
+
2136
+ Args:
2137
+ recursion_limit (int, optional): The recursion depth limit. Must be a positive integer. Default: ``1000`` .
2138
+
2139
+ Examples:
2140
+ >>> import mindspore as ms
2141
+ >>> ms.set_recursion_limit(10000)
2142
+ """
2143
+ recursion_limit = Validator.check_positive_int(recursion_limit)
2144
+ GraphExecutor_.get_instance().set_max_call_depth(recursion_limit)
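Because the docstring above warns that a larger recursion limit may also require a larger system stack, here is a hedged companion sketch using only the standard library. It is Linux-oriented (the `resource` module is not available on Windows) and the values are illustrative:

>>> import resource
>>> import mindspore as ms
>>> soft, hard = resource.getrlimit(resource.RLIMIT_STACK)
>>> resource.setrlimit(resource.RLIMIT_STACK, (hard, hard))  # raise the soft stack limit to the hard limit
>>> ms.set_recursion_limit(10000)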
2037
2145
 
2038
2146
 
2039
2147
  def _bind_device_context():
@@ -2058,4 +2166,4 @@ def flops_collection(phase='train'):
2058
2166
  _cell_graph_executor = _CellGraphExecutor()
2059
2167
  _pynative_executor = _PyNativeExecutor()
2060
2168
 
2061
- __all__ = ['ms_function', 'ms_memory_recycle', 'ms_class', 'jit', 'jit_class', 'flops_collection']
2169
+ __all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']