mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -17,6 +17,7 @@
  """Providing interface methods."""
  from __future__ import absolute_import

+ import gc
  import types
  import sys
  import os
@@ -24,11 +25,11 @@ import time
  import ast
  import inspect
  import importlib
- import hashlib
  import contextlib
+ import json
  from collections import OrderedDict, namedtuple
  from functools import wraps
- import numpy as np
+ from typing import Optional, Callable
  import mindspore as ms
  from mindspore import context
  from mindspore import log as logger
@@ -39,21 +40,23 @@ from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
  from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
  from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
  from mindspore._c_expression.amp import get_curr_amp_strategy
- from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
+ from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
  PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
- _ms_memory_recycle, _bind_device_ctx
+ _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, StubNode, MSContext, TensorPy as Tensor
  from mindspore.parallel._ps_context import _is_role_sched
- from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
- _is_in_auto_parallel_mode, _is_parallel_mode
+ from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
+ _is_parallel_mode
  from mindspore import _checkparam as Validator
  from mindspore._checkparam import is_stub_tensor
  from mindspore.common._utils import is_shape_unknown
- from mindspore.common.mutable import mutable
- from mindspore.common._register_for_adapter import ms_adapter_registry
+ from mindspore.common.mutable import mutable, _check_element_type
  from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
  get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
  from mindspore.common._pijit_context import PIJitCaptureContext
- from mindspore.common.parameter import Parameter
+ from mindspore.common.parameter import Parameter, set_parameter_hook_updated, parameter_hook_updated
+ from mindspore.common.jit_context import jit_context
+ from mindspore.common.jit_trace import _jit_trace
+ from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context

  # Store ms_function class compiled pipeline cache.
  ms_compile_cache = set()
@@ -107,8 +110,7 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
  logger.info(f"The {echo_function_name} has been compiled again. "
  f"{tips} ")
  else:
- tips = "Try to decorate the function with @jit(hash_args=...) " \
- "or @jit(compile_once=True) to reduce the compile time. " \
+ tips = "Try to reuse the function object decorated by @jit to reduce the compile time. " \
  "For more details, get instructions about `jit` at " \
  "https://www.mindspore.cn/search?inputValue=jit."
  logger.warning(f"The {echo_function_name} has been compiled again. "
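The replacement tip above drops the removed `hash_args` / `compile_once` guidance: the compile cache is keyed on the decorated function object, so reusing that object avoids recompilation. A minimal sketch of the intended usage (hypothetical shapes and values; assumes only the public `mindspore.jit` and `Tensor` APIs):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, jit

    @jit
    def square(x):
        return x * x

    x = Tensor(np.ones((2, 2)), ms.float32)
    for _ in range(3):
        # Calling the same decorated object hits the compile cache; re-creating
        # the decorated function on every iteration would trigger the
        # "has been compiled again" warning handled above.
        y = square(x)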
@@ -120,14 +122,6 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
  function_phases[full_function_name].add(create_time)


- def _ms_adapter_tensor_as_parameter_output(data):
- """Check whether the data is an output from a parameter which is a ms_adapter tensor.
- Pylint: disable=unidiomatic-typecheck.
- """
- return ms_adapter_registry.is_registered and isinstance(data, ms_adapter_registry.tensor) \
- and hasattr(data, "__ms_parameter_output__") and getattr(data, "__ms_parameter_output__")
-
-
  def _convert_python_data(data):
  """
  Convert C++ data to python.
@@ -138,18 +132,10 @@ def _convert_python_data(data):
  Returns:
  data, a data convert C++ to python
  """
- if isinstance(data, (Tensor, PythonTensor)) and data.adapter_flag:
- return ms_adapter_registry.tensor(data)
- if _ms_adapter_tensor_as_parameter_output(data) and hasattr(data, "tensor"):
- return data.tensor
- if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
- return PythonTensor(data, internal=True)
- if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
- return PythonCSRTensor(csr_tensor=data)
- if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
- return PythonCOOTensor(coo_tensor=data)
- if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
- return PythonRowTensor(row_tensor=data)
+ if isinstance(data, PythonTensor):
+ return data
+ if isinstance(data, StubNode):
+ return ms.common._stub_tensor._convert_stub(data)
  if data.__class__ is tuple:
  # Handle namedtuple since its type is tuple.
  if hasattr(data, "_fields"):
@@ -158,6 +144,12 @@ def _convert_python_data(data):
  fields = data_dict.keys()
  return namedtuple(type_name, fields)(**_convert_python_data(data_dict))
  return tuple(_convert_python_data(x) for x in data)
+ if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
+ return PythonCSRTensor(csr_tensor=data)
+ if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
+ return PythonCOOTensor(coo_tensor=data)
+ if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
+ return PythonRowTensor(row_tensor=data)
  if data.__class__ is list:
  # Keep list object not change for inplace operation.
  for i in range(len(data)):
@@ -273,7 +265,9 @@ def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
  else:
  whole_module = module_name
  if n.name is not None:
- whole_module += "." + n.name
+ if not whole_module.endswith("."):
+ whole_module += "."
+ whole_module += n.name
  try:
  module_spec = importlib.util.find_spec(whole_module, pkg)
  except (ModuleNotFoundError, ValueError):
@@ -305,7 +299,22 @@ def _get_compile_cache_dep_files():
  return compile_cache_dep_files


- def _restore_mutable_attr(args_list, compile_args):
+ def _contains_auto_grad_tensor(obj):
+ """Check object is or contains auto grad tensor element"""
+ if isinstance(obj, PythonTensor):
+ return obj._has_auto_grad()
+ if isinstance(obj, (tuple, list)):
+ for element in obj:
+ if _contains_auto_grad_tensor(element):
+ return True
+ if isinstance(obj, dict):
+ for key in obj:
+ if _contains_auto_grad_tensor(obj[key]):
+ return True
+ return False
+
+
+ def _add_mutable_attr(args_list, compile_args, is_grad):
  """Restore the mutable attr for every arg."""
  new_compile_args = ()
  for idx, arg in enumerate(args_list):
@@ -316,7 +325,12 @@ def _restore_mutable_attr(args_list, compile_args):
  else:
  new_compile_args += (mutable(compile_args[idx], False),)
  else:
- new_compile_args += (compile_args[idx],)
+ if is_grad and _contains_auto_grad_tensor(arg):
+ if not _check_element_type(arg):
+ raise RuntimeError("Input \"%s\" contains tensor with gradient but can not mutable." % (str(arg)))
+ new_compile_args += (mutable(compile_args[idx], False),)
+ else:
+ new_compile_args += (compile_args[idx],)
  return new_compile_args


@@ -330,6 +344,7 @@ def _get_parameter_layout():

  def _handle_arg(obj, arg, compile_arg):
  """Handle arg for runtime .If need handle the arg, return True"""
+ from mindspore._extends.parse import compile_config
  if isinstance(arg, PythonTensor):
  if arg.has_init:
  arg.init_data()
@@ -342,7 +357,8 @@ def _handle_arg(obj, arg, compile_arg):
  if isinstance(arg, list) and not arg:
  return None
  return arg
- elif context.get_context("grad_for_scalar") and isinstance(arg, (int, float)):
+ elif (context.get_context("grad_for_scalar") or str(compile_config.GRAD_FOR_SCALAR) == '1') and \
+ isinstance(arg, (int, float)):
  return arg
  elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
  _check_all_tensor(arg):
@@ -528,12 +544,35 @@ def _get_parameter_ids(args, kwargs):
  parameter_ids += str(id(value))
  return parameter_ids

+ def _get_tensor_hook_key(tensor):
+ """Get the hook key of Tensor/Parameter"""
+ return ".".join(map(str, map(id, tensor.hooks())))
+
+ def _get_hook_key(*args, **kwargs):
+ """Get the hook key of Tensors/Parameters"""
+ hook_key = ""
+ for idx, arg in enumerate(args):
+ if idx != 0:
+ hook_key += "."
+ # Only arg of the type Tensor or Parameter is supported now
+ if isinstance(arg, (Tensor, Parameter)):
+ hook_key += _get_tensor_hook_key(arg)
+
+ for idx, value in enumerate(kwargs.values()):
+ if idx != 0:
+ hook_key += "."
+ # Only kwarg of the type Tensor or Parameter is supported now
+ if isinstance(value, (Tensor, Parameter)):
+ hook_key += _get_tensor_hook_key(value)
+
+ return hook_key
+

- class _MindsporeFunctionExecutor:
+ class _JitExecutor:
  """
  Represents a function compiled by graph compiler.

- _MindsporeFunctionExecutor will compile the original function for every combination
+ _JitExecutor will compile the original function for every combination
  of argument types and shapes it is given (as well as their values, optionally).

  Args:
@@ -547,7 +586,7 @@ class _MindsporeFunctionExecutor:
  The result of pipeline running in graph mode.
  """

- def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None):
+ def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0):
  init_pipeline()
  if not isinstance(fn, (types.FunctionType, types.MethodType)):
  raise RuntimeError('fn {} is not function or method'.format(fn))
@@ -559,13 +598,19 @@ class _MindsporeFunctionExecutor:
  self.obj = obj
  self.shard_parent_obj = obj
  self.enable_tuple_broaden = False
- self._graph_executor = GraphExecutor_.get_instance()
+ if _run_jit_pipeline():
+ self._graph_executor = JitExecutor_.get_instance()
+ else:
+ self._graph_executor = GraphExecutor_.get_instance()
  self._create_time = ms_create_time
  self._compile_args = None
+ self._enable_auto_dynamic = dynamic == 1
  self.jit_config_dict = jit_config.jit_config_dict if jit_config else None

  @_wrap_func
  def __call__(self, *args, **kwargs):
+ if jit_context() and jit_context().is_nested():
+ return jit_context().run_graph("", None, *())
  args_list = args
  if self.obj is not None:
  args_list = args_list[1:]
@@ -581,13 +626,18 @@ class _MindsporeFunctionExecutor:
  _pynative_executor.clear_res()
  raise err

- if context.get_context("precompile_only"):
+ if context.get_context("precompile_only") or os.getenv('MS_DEV_PRECOMPILE_ONLY') == '1':
  return None

  new_inputs = self._generate_run_args(args_list, kwargs)
- output = self._graph_executor(tuple(new_inputs), phase)
- if context.get_context("mode") == context.PYNATIVE_MODE:
- output = _pynative_executor.grad_jit(output, *new_inputs)
+ if context.get_context("mode") == context.PYNATIVE_MODE and not jit_context():
+ output = _pynative_executor.grad_jit(*new_inputs)
+ else:
+ output = self._graph_executor(tuple(new_inputs), phase)
+ if jit_context():
+ if is_stub_tensor(output):
+ output = output.stub_sync()
+ return jit_context().run_graph(phase, output, *tuple(new_inputs))

  return output

@@ -603,10 +653,13 @@ class _MindsporeFunctionExecutor:
  compile_args = self._generate_compile_args(args)
  key_id = self._get_key_id()
  compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
- self.input_signature)
+ self.input_signature,
+ self._enable_auto_dynamic)

- # Restore the mutable attr for every arg.
- compile_args = _restore_mutable_attr(args, compile_args)
+ # Add mutable for compile_args for two scene:
+ # 1) Origin args is mutable.
+ # 2) Args contains sequence with gradient tensor.
+ compile_args = _add_mutable_attr(args, compile_args, _pynative_executor.requires_grad())
  self._compile_args = compile_args
  generate_name, echo_function_name = self._get_generate_name()
  # The full Function name
@@ -645,11 +698,14 @@ class _MindsporeFunctionExecutor:
  parameter_ids = _get_parameter_ids(args, kwargs)
  if parameter_ids != "":
  key = str(key) + '.' + parameter_ids
+
+ key = str(key) + "." + _get_hook_key(*args, **kwargs)
+
  phase = generate_name + '.' + str(key)

  update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)

- if phase in ms_compile_cache:
+ if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
  # generated in generate_arguments_key.
  self._graph_executor.clear_compile_arguments_resource()
@@ -671,7 +727,7 @@ class _MindsporeFunctionExecutor:
  setattr(self.fn.__func__, "__jit_function__", True)
  else:
  setattr(self.fn, "__jit_function__", True)
- is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
+ is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase)
  if isinstance(self.fn, types.MethodType):
  delattr(self.fn.__func__, "__jit_function__")
  else:
@@ -679,10 +735,11 @@ class _MindsporeFunctionExecutor:
  else:
  if isinstance(self.obj, ms.nn.Cell):
  self._graph_executor.set_weights_values(self.obj.parameters_dict())
- is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase, True)
+ is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase)

  if not is_compile:
  raise RuntimeError("Executor compile failed.")
+ set_parameter_hook_updated(False)
  ms_compile_cache.add(phase)

  return phase
@@ -704,7 +761,7 @@ class _MindsporeFunctionExecutor:
  else:
  key_id = str(id(self.obj)) + str(self._create_time)

- if _pynative_executor.grad_flag():
+ if _pynative_executor.requires_grad():
  key_id = key_id + ".grad"
  return key_id

@@ -714,9 +771,9 @@ class _MindsporeFunctionExecutor:
  self.fn.__code__.co_firstlineno)
  echo_function_name = "function \"" + self.fn.__name__ + "\" at the file \"" + self.fn.__code__.co_filename \
  + "\", line " + str(self.fn.__code__.co_firstlineno)
- if _pynative_executor.grad_flag():
+ if _pynative_executor.requires_grad():
  generate_name = generate_name + ".grad"
- if _is_pynative_parallel():
+ if self.fn.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
  generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
  return generate_name, echo_function_name

@@ -777,6 +834,14 @@ class _MindsporeFunctionExecutor:
  """
  return _get_args_for_run(self, args_list, kwargs, self._compile_args)

+ def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
+ """Get graph proto from pipeline."""
+ if use_prefix:
+ exec_id = exec_id + '.' + obj.arguments_key
+ if self._graph_executor.has_compiled(exec_id) is False:
+ return None
+ return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
+

  # The attributes used to identify a given object.
  attr_op = {"__str__": lambda x: x.__str__(),
@@ -789,6 +854,13 @@ attr_op = {"__str__": lambda x: x.__str__(),
  }


+ def _is_inner_func(func):
+ """Check whether the func is an inner func which needs hash_args parameter."""
+ # This is a workaround for inner api, should fix it later.
+ inner_func = ["after_shard", "_wrap_container"]
+ return func.__name__ in inner_func
+
+
  def _get_obj_id(input_obj):
  """Get hash id of single object."""
  obj_id = ".".join(
@@ -803,50 +875,227 @@ def _get_jit_hash(hash_input):
803
875
  return _get_obj_id(hash_input)
804
876
 
805
877
 
806
- def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=None, compile_once=False):
878
+ def _get_hash_obj(options):
879
+ hash_obj = None
880
+ if "hash_args" in options:
881
+ hash_obj = _get_jit_hash(options["hash_args"])
882
+ del options["hash_args"]
883
+ return hash_obj
884
+
885
+
886
+ def _check_option_device(option, device):
887
+ """Check jit options wiwh device"""
888
+ option_device_cfgs = {
889
+ 'disable_format_transform': ['GPU'],
890
+ 'exec_order': ['Ascend'],
891
+ 'ge_options': ['Ascend'],
892
+ 'infer_boost': ['Ascend'],
893
+ }
894
+ if option in option_device_cfgs and device not in option_device_cfgs[option]:
895
+ logger.warning(f"For 'jit(options)', the option '{option}' is only support device in "
896
+ f"'{option_device_cfgs[option]}', but got '{device}', ignore it.")
897
+
898
+
899
+ def _check_option_backend(option, backend):
900
+ """Check jit options wiwh backend"""
901
+ option_backend_cfgs = {
902
+ 'disable_format_transform': ['ms_backend'],
903
+ 'exec_order': ['ms_backend'],
904
+ 'ge_options': ['GE'],
905
+ 'infer_boost': ['ms_backend'],
906
+ }
907
+ if option in option_backend_cfgs and backend not in option_backend_cfgs[option]:
908
+ logger.warning(f"For 'jit(options)', the option '{option}' is only support backend in "
909
+ f"'{option_backend_cfgs[option]}', but got '{backend}', ignore it.")
910
+
911
+
912
+ def _check_disable_format_transform_value(option, disable_format_transform):
913
+ """check disable_format_transform option value"""
914
+ if not isinstance(disable_format_transform, bool):
915
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be bool, "
916
+ f"but got {type(disable_format_transform)}.")
917
+
918
+
919
+ def _check_exec_order_value(option, exec_order):
920
+ """check exec_order option value"""
921
+ if not isinstance(exec_order, str):
922
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(exec_order)}.")
923
+
924
+ if exec_order not in ['bfs', 'dfs']:
925
+ raise ValueError(f"For '{option}', the value of '{option}' must be one of "
926
+ f"['bfs', 'dfs'], but got '{exec_order}'.")
927
+
928
+
929
+ def _check_ge_options_value(option, ge_options):
930
+ """check ge_options option value"""
931
+ if not isinstance(ge_options, dict):
932
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be dict, but got {type(ge_options)}.")
933
+
934
+ for level, options in ge_options.items():
935
+ if level not in ['global', 'session']:
936
+ raise ValueError(f"For '{option}', the key of '{option}' must be one of "
937
+ f"['global', 'session'], but got '{level}'.")
938
+
939
+ if not isinstance(options, dict):
940
+ raise TypeError(f"For '{option}', the type of {level} options must be dict, "
941
+ f"but got {type(options)}. The error options: {options}.")
942
+
943
+ for key, value in options.items():
944
+ if not isinstance(key, str):
945
+ raise TypeError(f"For '{option}', the type of key and value must be str, "
946
+ f"but got {type(key)}. The error key is {key}.")
947
+ if not isinstance(value, str):
948
+ raise TypeError(f"For '{option}', the type of key and value must be str, "
949
+ f"but got {type(value)}. The error value is {value}")
950
+
951
+
952
+ def _check_infer_boost_value(option, value):
953
+ """check infer_boost option value"""
954
+ if not isinstance(value, str):
955
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(value)}.")
956
+
957
+ if value not in ['on', 'off']:
958
+ raise ValueError(f"For '{option}', the value of '{option}' must be one of ['on', 'off'], but got '{value}'.")
959
+
960
+
961
+ def _check_option_value(option, value):
962
+ """check jit options wiwh value"""
963
+ option_valuecheck_funcs = {
964
+ 'disable_format_transform': _check_disable_format_transform_value,
965
+ 'exec_order': _check_exec_order_value,
966
+ 'ge_options': _check_ge_options_value,
967
+ 'infer_boost': _check_infer_boost_value,
968
+ }
969
+ if option in option_valuecheck_funcs:
970
+ option_valuecheck_funcs[option](option, value)
971
+ else:
972
+ logger.warning(f"For 'jit(options)', the option argument '{option}' is not recognized, please check!"
973
+ f"For detailed usage of 'jit(options)', please refer to the Mindspore official website.")
974
+
975
+
976
+ def _check_options(options, backend):
977
+ """Check jit options"""
978
+ # check whether there are deprecated parameters in the dict `options`.
979
+ deprecated_args = {'mode': 'capture_mode', 'input_signature': 'dynamic', 'hash_args: ': '',
980
+ 'jit_config': 'jit_level, fullgraph or options', 'compile_once': ''}
981
+ for key, value in deprecated_args.items():
982
+ if key in options:
983
+ log = f"For 'jit', the parameter '{key}' has been deprecated."
984
+ if value != '':
985
+ log += f" Please use the parameter '{value}' instead. For more details, please refer to " \
986
+ f"https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html."
987
+ logger.warning(log)
988
+ del options[key]
989
+
990
+ # check options' device, backend and value
991
+ for option, value in options.items():
992
+ _check_option_backend(option, backend)
993
+ _check_option_value(option, value)
994
+
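As a sketch of the deprecation handling above: passing an old-style key such as `mode` in `options` logs a warning pointing at its replacement (`capture_mode`) and the key is dropped before the per-option checks run.

    opts = {"mode": "PIJit", "infer_boost": "on"}
    _check_options(opts, backend="ms_backend")
    # -> warning: "For 'jit', the parameter 'mode' has been deprecated. Please use the parameter 'capture_mode' ..."
    # Afterwards only {'infer_boost': 'on'} remains and is validated against the backend.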
995
+
996
+ def jit(
997
+ function: Optional[Callable] = None,
998
+ *,
999
+ capture_mode: str = "ast",
1000
+ jit_level: str = "O0",
1001
+ dynamic: int = 0,
1002
+ fullgraph: bool = False,
1003
+ backend: str = "",
1004
+ **options):
807
1005
  """
808
1006
  Create a callable MindSpore graph from a Python function.
809
1007
 
810
1008
  This allows the MindSpore runtime to apply optimizations based on graph.
811
1009
 
812
1010
  Note:
813
- - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
814
- will not accept `**kwargs`.
815
- - It is not supported to run a function with decoration @jit(mode=“PIJit”)
816
- in static graph mode, in which case the decoration @jit(mode=“PIJit”) is considered invalid.
817
- - Calls to functions with decorated @jit(mode=“PIJit”) inside functions
818
- decorated with @jit(mode=“PIJit”) are not supported,
819
- and the decoration @jit(mode=“PIJit”) is considered invalid.
1011
+ - It is not supported to run a function decorated with @jit(capture_mode="bytecode")
1012
+ in static graph mode, in which case the decorator @jit(capture_mode="bytecode") is considered invalid.
1013
+ - Calls to functions decorated with @jit(capture_mode="bytecode") inside functions
1014
+ decorated with @jit(capture_mode="ast") are not supported,
1015
+ and the decorator @jit(capture_mode="bytecode") is considered invalid.
820
1016
 
821
1017
  Args:
822
- fn (Function): The Python function that will be run as a graph. Default: ``None`` .
823
- mode (str): The type of jit used, the value of mode should be ``PIJit`` or ``PSJit``. Default: ``PSJit`` .
824
-
825
- - `PSJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html>`_ :
826
- Parse python ast to build graph.
827
- - `PIJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/pynative.html#pijit>`_ :
828
- Parse python bytecode to build graph at runtime.
829
-
830
- input_signature (Union[Tuple, List, Dict, Tensor]): The Tensor which describes the input arguments. The
831
- shape and dtype of the Tensor will be supplied to this function. If `input_signature` is specified, the
832
- input parameters of `fn` cannot accept `**kwargs`, and the shape and dtype of actual inputs should keep the
833
- same as `input_signature`. Otherwise, TypeError will be raised. There are two mode for `input_signature`:
834
-
835
- - Full mode: Arguments is a Tuple, List or a Tensor, and they will be used as all compile inputs
836
- for graph-compiling.
837
- - Incremental mode: Argument is a Dict, and they will set to some of the graph inputs, which will be
838
- substituted into the input at the corresponding position for graph-compiling.
839
-
840
- Default: ``None`` .
841
-
842
- hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
843
- like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
844
- will trigger recompilation. Default: ``None`` .
845
- jit_config (JitConfig): Jit config for compile. Default: ``None`` .
846
- compile_once(bool): ``True``: The function would be compiled once when it was created many times.
847
- But it may be wrong if the free variables were changed. ``False`` : It would be recompiled when
848
- it was created again.
849
- Default: ``False`` .
1018
+ function (Function, optional): The Python function that will be run as a graph. Default: ``None``.
1019
+
1020
+ Keyword Args:
1021
+ capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
1022
+ should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
1023
+
1024
+ - `ast <https://www.mindspore.cn/tutorials/en/master/compile/static_graph.html>`_ :
1025
+ Parse Python ast to build graph.
1026
+ - `bytecode` :
1027
+ Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
1028
+ change and/or deletion.
1029
+ - `trace` : Trace the execution of Python code to build graph. This is an experimental prototype that is
1030
+ subject to change and/or deletion.
1031
+
1032
+ jit_level (str, optional): Used to control the compilation optimization level. Currently it is only effective
1033
+ with the default backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
1034
+
1035
+ - `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
1036
+ - `O1`: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
1037
+ level is experimental and is being improved.
1038
+
1039
+ dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
1040
+ is as follows:
1041
+
1042
+ - `0`: Do not perform dynamic shape compilation.
1043
+ - `1`: Enable dynamic shape compilation and automatically detect shape changes.
1044
+
1045
+ fullgraph (bool, optional): Whether to capture the entire function into the graph. If False, jit attempts to
1046
+ be compatible with as much Python syntax in the function as possible. If True, the entire function must be
1047
+ captured into the graph. If this is not possible (that is, if unsupported Python syntax is present),
1048
+ an exception will be raised. This currently only applies when capture_mode is ast.
1049
+ Default: ``False``.
1050
+ backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
1051
+ use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including Atlas
1052
+ A2 training series products by default.
1053
+
1054
+ - `ms_backend`: Adopt KernelByKernel execution mode.
1055
+ - `GE`: Adopt Sink execution mode. The whole model is sunk to the device for execution; it is only applicable
1056
+ to the top cell of the model and can only be used on the Ascend platform.
1057
+
1058
+ **options (dict): A dictionary of options to pass to the compilation backend.
1059
+
1060
+ Some options are device specific, see the below table for details:
1061
+
1062
+ +---------------------------+---------------------------+-------------------------+
1063
+ | Option Parameters | Hardware Platform Support | Backend Support |
1064
+ +===========================+===========================+=========================+
1065
+ | disable_format_transform | GPU | ms_backend |
1066
+ +---------------------------+---------------------------+-------------------------+
1067
+ | exec_order | Ascend | ms_backend |
1068
+ +---------------------------+---------------------------+-------------------------+
1069
+ | ge_options | Ascend | GE |
1070
+ +---------------------------+---------------------------+-------------------------+
1071
+ | infer_boost | Ascend | ms_backend |
1072
+ +---------------------------+---------------------------+-------------------------+
1073
+
1074
+ - disable_format_transform (bool, optional): Whether to disable the automatic format transform function
1075
+ from NCHW to NHWC. When the network training performance of fp16 is worse than that of fp32,
1076
+ `disable_format_transform` can be set to ``True`` to try to improve training performance.
1077
+ Default: ``False`` .
1078
+ - exec_order (str, optional): Set the sorting method for operator execution. Currently only two sorting
1079
+ methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
1080
+
1081
+ - `bfs`: The default sorting method, breadth-first, with good communication masking and relatively good
1082
+ performance.
1083
+ - `dfs`: An optional depth-first sorting method. Its performance is relatively worse than that of the bfs
1084
+ execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
1085
+ other execution orders run out of memory (OOM).
1086
+
1087
+ - ge_options (dict): Set options for ge backend. The options are divided into two categories: global,
1088
+ and session. This is an experimental prototype that is subject to change and/or deletion.
1089
+ For detailed information, please refer to `Ascend community <https://www.hiascend.com/document/detail/zh/canncommercial/80RC3/apiref/ascendgraphapi/atlasgeapi_07_0146.html>`_ .
1090
+
1091
+ - global (dict): Set global options.
1092
+ - session (dict): Set session options.
1093
+
1094
+ - infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
1095
+ the inference mode is disabled. The range is as follows:
1096
+
1097
+ - `on`: Enable inference mode for better inference performance.
1098
+ - `off`: Disable inference mode and use the normal forward pass for inference. The performance is lower.
850
1099
 
851
1100
  Returns:
852
1101
  Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
@@ -865,12 +1114,12 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
865
1114
  >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
866
1115
  >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
867
1116
  ...
868
- >>> # create a callable MindSpore graph by calling decorator @jit
1117
+ >>> # create a callable MindSpore graph by calling jit
869
1118
  >>> def tensor_add(x, y):
870
1119
  ... z = x + y
871
1120
  ... return z
872
1121
  ...
873
- >>> tensor_add_graph = jit(fn=tensor_add)
1122
+ >>> tensor_add_graph = jit(function=tensor_add)
874
1123
  >>> out = tensor_add_graph(x, y)
875
1124
  ...
876
1125
  >>> # create a callable MindSpore graph through decorator @jit
@@ -881,180 +1130,70 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
881
1130
  ...
882
1131
  >>> out = tensor_add_with_dec(x, y)
883
1132
  ...
884
- >>> # create a callable MindSpore graph through decorator @jit with input_signature parameter
885
- >>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
886
- ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
887
- ... def tensor_add_with_sig(x, y):
888
- ... z = x + y
889
- ... return z
890
- ...
891
- >>> out = tensor_add_with_sig(x, y)
892
- ...
893
- >>> @jit(input_signature={"y": Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))})
894
- ... def tensor_add_with_sig_1(x, y):
1133
+ >>> # create a callable MindSpore graph and capture the entire function into the graph
1134
+ >>> @jit(fullgraph=True)
1135
+ ... def tensor_add_fullgraph(x, y):
895
1136
  ... z = x + y
896
1137
  ... return z
897
1138
  ...
898
- >>> out1 = tensor_add_with_sig_1(x, y)
899
- ...
900
- ... # Set hash_args as fn, otherwise cache of compiled closure_fn will not be reused.
901
- ... # While fn differs during calling again, recompilation will be triggered.
902
- >>> def func(x):
903
- ... return ops.exp(x)
904
- ...
905
- >>> def closure_fn(x, fn):
906
- ... @jit(hash_args=fn)
907
- ... def inner_fn(a):
908
- ... return fn(a)
909
- ... return inner_fn(x)
910
- ...
911
- >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
912
- >>> for i in range(10):
913
- ... closure_fn(inputs, func)
914
- ...
915
- ... # Set compile_once = True, otherwise the train_step will be compiled again.
916
- >>> def train(x):
917
- ... @jit(compile_once = True)
918
- ... def train_step(x):
919
- ... return ops.exp(x)
920
- ... for i in range(10):
921
- ... train_step(x)
922
- ...
923
- >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
924
- >>> for i in range(10):
925
- ... train(inputs)
1139
+ >>> out = tensor_add_fullgraph(x, y)
926
1140
  """
927
1141
 
928
- def wrap_mindspore(func):
929
- if not isinstance(compile_once, bool):
930
- logger.warning(f"The parameter `compile_once` of jit should be a bool, "
931
- f"but got {type(compile_once)}.")
932
- if hash_args:
933
- hash_obj = _get_jit_hash(hash_args)
934
- elif compile_once:
935
- hash_obj = 0
936
- else:
1142
+ capture_mode = Validator.check_string(capture_mode, ["ast", "bytecode", "trace"], "capture_mode", "jit")
1143
+ jit_level = Validator.check_string(jit_level, ["O0", "O1"], "jit_level", "jit")
1144
+ dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
1145
+ fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
1146
+ if backend == "":
1147
+ backend = "GE" if MSContext.get_instance().get_ascend_soc_version() == "ascend910" else "ms_backend"
1148
+ backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
1149
+ jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
1150
+ hash_obj = _get_hash_obj(options)
1151
+ _check_options(options, backend)
1152
+ options_str = json.dumps(options)
1153
+ infer_boost = options['infer_boost'] if 'infer_boost' in options else "off"
1154
+ exc_mode = options['exc_mode'] if 'exc_mode' in options else "auto"
1155
+ jit_config = JitConfig(jit_level=jit_level, exc_mode=exc_mode, jit_syntax_level=jit_syntax_level,
1156
+ infer_boost=infer_boost, backend=backend, options=options_str)
1157
+
1158
+ def wrap_func(func):
1159
+ nonlocal hash_obj
1160
+ if hash_obj is None or not _is_inner_func(func):
937
1161
  hash_obj = int(time.time() * 1e9)
938
1162
 
939
- dyn_args = _process_dyn_args(func, input_signature)
940
-
941
1163
  @wraps(func)
942
1164
  def staging_specialize(*args, **kwargs):
943
1165
  if os.getenv("MS_JIT") == '0':
944
1166
  return func(*args, **kwargs)
945
1167
 
946
1168
  args, kwargs = _handle_func_args(func, *args, **kwargs)
947
-
948
1169
  process_obj = None
949
1170
  if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
950
1171
  process_obj = args[0]
951
- # only the function or cell instance wrapped by shard will fall into this branch
952
- if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
953
- process_obj = hash_args
954
1172
  # Handle auto mixed precision strategy.
955
1173
  if not hasattr(func, "amp_strategy"):
956
1174
  if isinstance(func, types.MethodType):
957
1175
  setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
958
1176
  else:
959
1177
  setattr(func, "amp_strategy", get_curr_amp_strategy())
960
- out = _MindsporeFunctionExecutor(func, hash_obj, dyn_args, process_obj, jit_config)(*args, **kwargs)
1178
+
1179
+ ms_function_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic)
1180
+ out = ms_function_executor(*args, **kwargs)
961
1181
  return out
962
1182
 
963
1183
  return staging_specialize
964
1184
 
965
- wrap_func = wrap_mindspore
966
- if mode == "PIJit":
967
- wrap_func = PIJitCaptureContext(jit_config, input_signature)
1185
+ if capture_mode == "bytecode":
1186
+ wrap_func = PIJitCaptureContext(jit_config)
1187
+ elif capture_mode == "trace":
1188
+ if function is not None:
1189
+ return _jit_trace(function)
1190
+ return _jit_trace
968
1191
 
969
- if fn is not None:
970
- return wrap_func(fn)
1192
+ if function is not None:
1193
+ return wrap_func(function)
971
1194
  return wrap_func
972
1195
 
973
1196
 
974
- def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
975
- """
976
- Create a callable MindSpore graph from a Python function.
977
-
978
- This allows the MindSpore runtime to apply optimizations based on graph.
979
-
980
- Note:
981
- - `ms_function` will be deprecated and removed in a future version. Please use :func:`mindspore.jit` instead.
982
- - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
983
- will not accept `**kwargs`.
984
-
985
- Args:
986
- fn (Function): The Python function that will be run as a graph. Default: ``None`` .
987
- input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
988
- will be supplied to this function. The shape and dtype of actual inputs of `fn` should
989
- keep the same as input_signature. Otherwise, TypeError will be raised. Default: ``None`` .
990
- hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
991
- like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
992
- will trigger recompilation. Default: ``None`` .
993
- jit_config (JitConfig): Jit config for compile. Default: ``None`` .
994
-
995
- Returns:
996
- Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
997
- None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
998
- equal to the case when `fn` is not None.
999
-
1000
- Supported Platforms:
1001
- ``Ascend`` ``GPU`` ``CPU``
1002
-
1003
- Examples:
1004
- >>> import numpy as np
1005
- >>> from mindspore import Tensor
1006
- >>> from mindspore import ops
1007
- >>> from mindspore import ms_function
1008
- ...
1009
- >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
1010
- >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
1011
- ...
1012
- >>> # create a callable MindSpore graph by calling ms_function
1013
- >>> def tensor_add(x, y):
1014
- ... z = x + y
1015
- ... return z
1016
- ...
1017
- >>> tensor_add_graph = ms_function(fn=tensor_add)
1018
- >>> out = tensor_add_graph(x, y)
1019
- ...
1020
- >>> # create a callable MindSpore graph through decorator @ms_function
1021
- >>> @ms_function
1022
- ... def tensor_add_with_dec(x, y):
1023
- ... z = x + y
1024
- ... return z
1025
- ...
1026
- >>> out = tensor_add_with_dec(x, y)
1027
- ...
1028
- >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
1029
- >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
1030
- ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
1031
- ... def tensor_add_with_sig(x, y):
1032
- ... z = x + y
1033
- ... return z
1034
- ...
1035
- >>> out = tensor_add_with_sig(x, y)
1036
- ...
1037
- ... # Set hash_args as fn, otherwise cache of compiled `closure_fn` will not be reused.
1038
- ... # While fn differs during calling again, recompilation will be triggered.
1039
- >>> def func(x):
1040
- ... return ops.exp(x)
1041
- ...
1042
- >>> def closure_fn(x, fn):
1043
- ... @ms_function(hash_args=fn)
1044
- ... def inner_fn(a):
1045
- ... return fn(a)
1046
- ... return inner_fn(x)
1047
- ...
1048
- >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
1049
- >>> for i in range(10):
1050
- ... closure_fn(inputs, func)
1051
- """
1052
-
1053
- logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
1054
- "Please use 'mindspore.jit' instead.")
1055
- return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)
1056
-
1057
-
1058
1197
  def _core(fn=None, **flags):
1059
1198
  """
1060
1199
  A decorator that adds a flag to the function.
@@ -1147,69 +1286,6 @@ def _no_recursive(callable_obj):
1147
1286
  return callable_obj
1148
1287
 
1149
1288
 
1150
- def ms_class(cls):
1151
- """
1152
- Class decorator for user-defined classes.
1153
-
1154
- This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
1155
-
1156
- Note:
1157
- `ms_class` will be deprecated and removed in a future version. Please use :func:`mindspore.jit_class` instead.
1158
-
1159
- Args:
1160
- cls (Class): User-defined class.
1161
-
1162
- Returns:
1163
- Class.
1164
-
1165
- Raises:
1166
- TypeError: If ms_class is used for non-class types or nn.Cell.
1167
- AttributeError: If the private attributes or magic methods of the class decorated with ms_class is called.
1168
-
1169
- Supported Platforms:
1170
- ``Ascend`` ``GPU`` ``CPU``
1171
-
1172
- Examples:
1173
- >>> import mindspore.nn as nn
1174
- >>> from mindspore import ms_class
1175
- ...
1176
- >>> @ms_class
1177
- ... class UserDefinedNet:
1178
- ... def __init__(self):
1179
- ... self.value = 10
1180
- ...
1181
- ... def func(self, x):
1182
- ... return 2 * x
1183
- ...
1184
- >>> class Net(nn.Cell):
1185
- ... def __init__(self):
1186
- ... super(Net, self).__init__()
1187
- ... self.net = UserDefinedNet()
1188
- ...
1189
- ... def construct(self, x):
1190
- ... out = self.net.value + self.net.func(x)
1191
- ... return out
1192
- ...
1193
- >>> net = Net()
1194
- >>> out = net(5)
1195
- >>> print(out)
1196
- 20
1197
- """
1198
-
1199
- logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
1200
- "Please use 'mindspore.jit_class' instead.")
1201
-
1202
- # Check if cls is of type class.
1203
- if not inspect.isclass(cls):
1204
- raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
1205
- # Check if cls is nn.Cell.
1206
- if issubclass(cls, ms.nn.Cell):
1207
- raise TypeError(f"Decorator ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
1208
- logger.info(f'Found ms_class: {cls}.')
1209
- setattr(cls, '__ms_class__', True)
1210
- return cls
1211
-
1212
-
1213
1289
  def jit_class(cls):
1214
1290
  """
1215
1291
  Class decorator for user-defined classes.
@@ -1266,28 +1342,6 @@ def jit_class(cls):
1266
1342
  return cls
1267
1343
 
1268
1344
 
1269
- def set_adapter_config(config):
1270
- """
1271
- Register configuration information for MSAdapter.
1272
-
1273
- Args:
1274
- config (dict): Configuration information.
1275
- """
1276
- if not isinstance(config, dict):
1277
- raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
1278
- for key, value in config.items():
1279
- if key == "Tensor":
1280
- ms_adapter_registry.register_tensor(value)
1281
- elif key == "Parameter":
1282
- ms_adapter_registry.register_parameter(value)
1283
- elif key == "convert_object_map":
1284
- ms_adapter_registry.register_convert_map(value)
1285
- elif key == "convert_adapter_tensor_map":
1286
- ms_adapter_registry.register_convert_adapter_tensor_map(value)
1287
- else:
1288
- raise ValueError(f"Unsupported key in adapter config: {key}")
1289
-
1290
-
1291
1345
  def _function_forbid_reuse(func):
1292
1346
  if not inspect.isfunction(func):
1293
1347
  raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
@@ -1351,8 +1405,6 @@ class _no_grad(contextlib.ContextDecorator):
1351
1405
  self.prev_state = False
1352
1406
 
1353
1407
  def __enter__(self):
1354
- if context.get_context("mode") == context.GRAPH_MODE:
1355
- raise RuntimeError("For no_grad feature, currently only support Pynative mode, but got Graph mode.")
1356
1408
  self.prev_state = _pynative_executor.enable_grad()
1357
1409
  _pynative_executor.set_enable_grad(False)
1358
1410
 
@@ -1481,7 +1533,24 @@ class _PyNativeExecutor:
1481
1533
  Return:
1482
1534
  None.
1483
1535
  """
1484
- return self._executor.grad(grad, obj, weights, grad_position, *args)
1536
+ return self._executor.grad(grad, obj, weights, grad_position, False, *args)
1537
+
1538
+ def grad_aux(self, obj, grad, weights, grad_position, *args):
1539
+ """
1540
+ Run grad graph with aux
1541
+
1542
+ Args:
1543
+ obj (Function/Cell): The function or cell instance.
1544
+ grad (GradOperation): The GradOperation object.
1545
+ weights (ParameterTuple): The weights of cell instance.
1546
+ grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
1547
+ If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
1548
+ args (tuple): Function or cell input arguments.
1549
+
1550
+ Return:
1551
+ None.
1552
+ """
1553
+ return self._executor.grad(grad, obj, weights, grad_position, True, *args)
1485
1554
 
1486
1555
  def clear_res(self):
1487
1556
  """
@@ -1501,18 +1570,18 @@ class _PyNativeExecutor:
1501
1570
  """
1502
1571
  self._executor.sync()
1503
1572
 
1504
- def grad_jit(self, output, *args):
1573
+ def grad_jit(self, *args):
1505
1574
  """
1506
1575
  Building grad graph decorated by jit.
1507
1576
 
1508
1577
  Args:
1509
- output (tuple): The function or cell decorated by jit output object.
1510
1578
  args (tuple): Function or cell decorated by jit input arguments.
1511
1579
 
1512
1580
  Return:
1513
- None.
1581
+ output: The output object of function or cell decorated by jit.
1514
1582
  """
1515
- return self._executor.grad_jit(output, *args)
1583
+ output = self._executor.grad_jit(*args)
1584
+ return output
1516
1585
 
1517
1586
  def call_custom_bprop(self, obj, output, *args, **kwargs):
1518
1587
  """
@@ -1617,6 +1686,15 @@ class _PyNativeExecutor:
1617
1686
  """
1618
1687
  self._executor.set_is_run_recompute(status)
1619
1688
 
1689
+ def high_order(self):
1690
+ """
1691
+ Whether the current scene is high-order. This is an inner interface.
1692
+
1693
+ Return:
1694
+ Bool.
1695
+ """
1696
+ return self._executor.high_order()
1697
+
1620
1698
  def set_cell_use_dynamic_shape_process(self, flag):
1621
1699
  """
1622
1700
  Set the dynamic shape flag of eval process.
@@ -1699,7 +1777,6 @@ class _CellGraphExecutor:
1699
1777
  # create needed graph by lazy mode
1700
1778
  self.is_init = False
1701
1779
  self.enable_tuple_broaden = False
1702
- self.obfuscate_config = None # used for model's dynamic obfuscation
1703
1780
  self._graph_executor = GraphExecutor_.get_instance()
1704
1781
  self._graph_executor.set_py_exe_path(sys.executable)
1705
1782
  self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
@@ -1791,6 +1868,7 @@ class _CellGraphExecutor:
1791
1868
  Str, the full phase of the cell.
1792
1869
  Bool, if the graph has been compiled before, return False, else return True.
1793
1870
  """
1871
+ _init_auto_parallel_context(obj)
1794
1872
  obj.__parse_method__ = 'construct'
1795
1873
  if not hasattr(obj, obj.__parse_method__):
1796
1874
  raise AttributeError(
@@ -1803,8 +1881,12 @@ class _CellGraphExecutor:
1803
1881
  self.enable_tuple_broaden = obj.enable_tuple_broaden
1804
1882
  logger.debug(f"Convert the network: {do_convert}.")
1805
1883
  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
1884
+
1806
1885
  key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
1807
1886
  obj.arguments_key = str(key)
1887
+
1888
+ obj.arguments_key = obj.arguments_key + "." + _get_hook_key(*args, **kwargs)
1889
+
1808
1890
  # When exist parameter in the top graph inputs, need check if the parameter object has changed.
1809
1891
  parameter_ids = _get_parameter_ids(args, kwargs)
1810
1892
  if parameter_ids != "":
@@ -1814,11 +1896,12 @@ class _CellGraphExecutor:
1814
1896
  obj.phase_cache[raw_phase] = phase
1815
1897
  update_auto_dynamic_shape_phase(args, key_id, phase)
1816
1898
  obj.current_phase = phase
1817
- if phase in obj.compile_cache and self.has_compiled(phase):
1899
+ if phase in obj.compile_cache and self.has_compiled(phase) and not parameter_hook_updated():
1818
1900
  logger.debug("%r graph has existed.", phase)
1819
1901
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
1820
1902
  # generated in generate_arguments_key.
1821
1903
  self._graph_executor.clear_compile_arguments_resource()
1904
+ _clear_auto_parallel_context(obj)
1822
1905
  return phase, False
1823
1906
 
1824
1907
  full_function_name = obj.__class__.__name__ + '.' + str(obj.instance_count) + '.' + str(id(type(obj)))
@@ -1836,10 +1919,12 @@ class _CellGraphExecutor:
1836
1919
  else:
1837
1920
  jit_config_dict = JitConfig().jit_config_dict
1838
1921
  self._graph_executor.set_jit_config(jit_config_dict)
1839
- result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
1922
+ gc.collect()
1923
+ result = self._graph_executor.compile(obj, args, kwargs, phase)
1840
1924
  obj.compile_cache.add(phase)
1841
1925
  if not result:
1842
1926
  raise RuntimeError("Executor compile failed.")
1927
+ set_parameter_hook_updated(False)
1843
1928
  graph = self._graph_executor.get_func_graph(phase)
1844
1929
 
1845
1930
  if graph is None:
@@ -1856,6 +1941,7 @@ class _CellGraphExecutor:
1856
1941
  self._build_data_graph(obj, phase)
1857
1942
  elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
1858
1943
  _parameter_broadcast(obj)
1944
+ _clear_auto_parallel_context(obj)
1859
1945
  return phase, True
1860
1946
 
1861
1947
  def _update_param_node_default_input(self, phase, replace):
@@ -1875,7 +1961,7 @@ class _CellGraphExecutor:
1875
1961
  return self._graph_executor.get_allreduce_fusion(real_phase)
1876
1962
 
1877
1963
  def __call__(self, obj, *args, phase='predict'):
1878
- if context.get_context("precompile_only") or _is_role_sched():
1964
+ if context.get_context("precompile_only") or os.getenv('MS_DEV_PRECOMPILE_ONLY') == '1' or _is_role_sched():
1879
1965
  return None
1880
1966
  return self.run(obj, *args, phase=phase)
1881
1967
 
@@ -1935,25 +2021,12 @@ class _CellGraphExecutor:
1935
2021
  """Clear the memory resource of a network."""
1936
2022
  self._graph_executor.del_net_res(obj, net_id)
1937
2023
 
1938
- def _get_branch_control_input(self):
1939
- if ('obf_ratio' not in self.obfuscate_config.keys()) or (
1940
- 'obf_random_seed' not in self.obfuscate_config.keys()):
1941
- raise ValueError("'obf_ratio' and 'obf_random_seed' must be in obfuscate_config.")
1942
- obf_random_seed = self.obfuscate_config.get('obf_random_seed')
1943
- if obf_random_seed == 0:
1944
- branch_control_input = 0
1945
- else:
1946
- branch_control_input = _generate_branch_control_input(obf_random_seed)
1947
- return branch_control_input
1948
-
1949
2024
  def _get_func_graph(self, obj, exec_id, use_prefix=False):
1950
2025
  """Get func graph from pipeline."""
1951
2026
  if use_prefix:
1952
2027
  exec_id = exec_id + '.' + obj.arguments_key
1953
2028
  if self._graph_executor.has_compiled(exec_id) is False:
1954
2029
  return None
1955
- if self.obfuscate_config is not None:
1956
- raise ValueError('For get func graph, obfuscate_config is currently not supported now.')
1957
2030
  return self._graph_executor.get_func_graph(exec_id)
1958
2031
 
1959
2032
  def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
@@ -1962,11 +2035,6 @@ class _CellGraphExecutor:
1962
2035
  exec_id = exec_id + '.' + obj.arguments_key
1963
2036
  if self._graph_executor.has_compiled(exec_id) is False:
1964
2037
  return None
1965
- if self.obfuscate_config is not None:
1966
- branch_control_input = self._get_branch_control_input()
1967
- return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, incremental,
1968
- self.obfuscate_config['obf_ratio'],
1969
- branch_control_input)
1970
2038
  return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
1971
2039
 
1972
2040
  def get_optimize_graph_proto(self, obj):
@@ -2004,6 +2072,8 @@ def ms_memory_recycle():
2004
2072
  """
2005
2073
  if ms_compile_cache:
2006
2074
  _cell_graph_executor.del_net_res(None, ms_compile_cache)
2075
+ if os.getenv('MS_DEV_JIT_PIPELINE') != '0':
2076
+ JitExecutor_.get_instance().del_net_res(None, ms_compile_cache)
2007
2077
  ms_compile_cache.clear()
2008
2078
  for cell_cache in cells_compile_cache.values():
2009
2079
  if cell_cache:
@@ -2012,28 +2082,22 @@ def ms_memory_recycle():
2012
2082
  _ms_memory_recycle()
2013
2083
 
2014
2084
 
2015
- def _generate_branch_control_input(obf_random_seed):
2016
- """Generate append network input for dynamic obfuscation in random seed mode."""
2017
- seed_max = 2 ** 32 - 1
2018
- int_max = 2 ** 31 - 1
2019
- np.random.seed(obf_random_seed % seed_max)
2020
- # generate a string as hash function inputs
2021
- word_repo = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghigklmnopqrstuvwxyz" + "0123456789"
2022
- repo_len = len(word_repo)
2023
- sha_string = ''
2024
- string_len = 1024 * 1024
2025
- for _ in range(string_len):
2026
- rand_index = np.random.randint(0, repo_len)
2027
- sha_string += word_repo[rand_index]
2028
- # get hash result
2029
- sha_result = hashlib.sha256(sha_string.encode('utf-8')).hexdigest() # len is 64
2030
- branch_control_input = 1
2031
- hex_base = 16
2032
- for item in sha_result:
2033
- if int(item, hex_base) > 0:
2034
- branch_control_input *= int(item, hex_base)
2035
- branch_control_input %= int_max
2036
- return branch_control_input
2085
+ def set_recursion_limit(recursion_limit=1000):
2086
+ """
2087
+ Specify the recursion depth limit of function calls before compiling the graph.
2088
+ It needs to be called when nested function calls are too deep or the number of subgraphs is too large.
2089
+ If recursion_limit is set larger than before, the system stack depth limit should also be increased,
2090
+ otherwise a `core dumped` error may be raised because of system stack overflow.
2091
+
2092
+ Args:
2093
+ recursion_limit (int, optional): The recursion depth limit. Must be a positive integer. Default: ``1000`` .
2094
+
2095
+ Examples:
2096
+ >>> import mindspore as ms
2097
+ >>> ms.set_recursion_limit(10000)
2098
+ """
2099
+ recursion_limit = Validator.check_positive_int(recursion_limit)
2100
+ GraphExecutor_.get_instance().set_max_call_depth(recursion_limit)
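Since the docstring above notes that the system stack limit may also need raising, a hedged companion sketch for Linux follows; the 64 MB value is an illustrative choice, and the `resource` module is POSIX-only, so this is not part of the MindSpore API itself.

    import resource
    import mindspore as ms

    # Raise the process stack limit first (illustrative 64 MB), then the graph recursion limit.
    # Without a matching OS stack limit, very deep graph recursion can still crash with "core dumped".
    soft, hard = resource.getrlimit(resource.RLIMIT_STACK)
    wanted = 64 * 1024 * 1024
    new_soft = wanted if hard == resource.RLIM_INFINITY else min(wanted, hard)
    resource.setrlimit(resource.RLIMIT_STACK, (new_soft, hard))

    ms.set_recursion_limit(10000)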
2037
2101
 
2038
2102
 
2039
2103
  def _bind_device_context():
@@ -2058,4 +2122,4 @@ def flops_collection(phase='train'):
2058
2122
  _cell_graph_executor = _CellGraphExecutor()
2059
2123
  _pynative_executor = _PyNativeExecutor()
2060
2124
 
2061
- __all__ = ['ms_function', 'ms_memory_recycle', 'ms_class', 'jit', 'jit_class', 'flops_collection']
2125
+ __all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']