mindspore 2.5.0__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (491)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +24 -193
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +97 -74
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +1915 -3287
  46. mindspore/common/api.py +341 -354
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/hook_handle.py +5 -3
  52. mindspore/common/initializer.py +10 -6
  53. mindspore/common/jit_begin_end.py +94 -0
  54. mindspore/common/jit_config.py +6 -1
  55. mindspore/common/jit_context.py +76 -0
  56. mindspore/common/jit_trace.py +378 -0
  57. mindspore/common/lazy_inline.py +2 -2
  58. mindspore/common/mutable.py +5 -4
  59. mindspore/common/parameter.py +106 -39
  60. mindspore/common/seed.py +2 -2
  61. mindspore/common/sparse_tensor.py +23 -17
  62. mindspore/common/tensor.py +297 -714
  63. mindspore/communication/__init__.py +7 -5
  64. mindspore/communication/_comm_helper.py +47 -2
  65. mindspore/communication/comm_func.py +70 -53
  66. mindspore/communication/management.py +83 -17
  67. mindspore/context.py +214 -560
  68. mindspore/dataset/__init__.py +44 -20
  69. mindspore/dataset/audio/__init__.py +2 -8
  70. mindspore/dataset/audio/transforms.py +3 -17
  71. mindspore/dataset/core/config.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +1 -1
  73. mindspore/dataset/engine/datasets.py +102 -120
  74. mindspore/dataset/engine/datasets_audio.py +22 -22
  75. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  76. mindspore/dataset/engine/datasets_text.py +78 -85
  77. mindspore/dataset/engine/datasets_user_defined.py +108 -76
  78. mindspore/dataset/engine/datasets_vision.py +111 -108
  79. mindspore/dataset/engine/iterators.py +5 -3
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/samplers.py +279 -57
  82. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  83. mindspore/dataset/engine/validators.py +10 -0
  84. mindspore/dataset/text/__init__.py +7 -6
  85. mindspore/dataset/text/transforms.py +6 -5
  86. mindspore/dataset/text/utils.py +3 -3
  87. mindspore/dataset/transforms/__init__.py +0 -9
  88. mindspore/dataset/transforms/transforms.py +3 -3
  89. mindspore/dataset/utils/browse_dataset.py +1 -1
  90. mindspore/dataset/vision/__init__.py +2 -9
  91. mindspore/dataset/vision/transforms.py +202 -158
  92. mindspore/dataset/vision/utils.py +7 -5
  93. mindspore/device_context/ascend/op_debug.py +60 -1
  94. mindspore/device_context/ascend/op_tuning.py +0 -4
  95. mindspore/device_manager.py +39 -3
  96. mindspore/dnnl.dll +0 -0
  97. mindspore/dpcmi.dll +0 -0
  98. mindspore/experimental/es/embedding_service.py +35 -27
  99. mindspore/experimental/map_parameter.py +4 -4
  100. mindspore/experimental/optim/adadelta.py +22 -26
  101. mindspore/experimental/optim/adagrad.py +4 -4
  102. mindspore/experimental/optim/adam.py +4 -0
  103. mindspore/experimental/optim/adamax.py +4 -4
  104. mindspore/experimental/optim/adamw.py +4 -0
  105. mindspore/experimental/optim/asgd.py +1 -1
  106. mindspore/experimental/optim/lr_scheduler.py +40 -22
  107. mindspore/experimental/optim/radam.py +5 -5
  108. mindspore/experimental/optim/rprop.py +1 -1
  109. mindspore/experimental/optim/sgd.py +1 -1
  110. mindspore/hal/contiguous_tensors_handle.py +6 -10
  111. mindspore/hal/device.py +55 -81
  112. mindspore/hal/event.py +38 -55
  113. mindspore/hal/memory.py +93 -144
  114. mindspore/hal/stream.py +81 -125
  115. mindspore/include/dataset/constants.h +7 -4
  116. mindspore/include/dataset/execute.h +2 -2
  117. mindspore/jpeg62.dll +0 -0
  118. mindspore/log.py +40 -2
  119. mindspore/mindrecord/__init__.py +20 -7
  120. mindspore/mindspore_backend_common.dll +0 -0
  121. mindspore/mindspore_backend_manager.dll +0 -0
  122. mindspore/mindspore_common.dll +0 -0
  123. mindspore/mindspore_core.dll +0 -0
  124. mindspore/mindspore_dump.dll +0 -0
  125. mindspore/mindspore_frontend.dll +0 -0
  126. mindspore/mindspore_glog.dll +0 -0
  127. mindspore/mindspore_memory_pool.dll +0 -0
  128. mindspore/mindspore_ms_backend.dll +0 -0
  129. mindspore/mindspore_ops.dll +0 -0
  130. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  131. mindspore/mindspore_ops_kernel_common.dll +0 -0
  132. mindspore/mindspore_profiler.dll +0 -0
  133. mindspore/mindspore_pyboost.dll +0 -0
  134. mindspore/mindspore_pynative.dll +0 -0
  135. mindspore/mindspore_res_manager.dll +0 -0
  136. mindspore/mindspore_runtime_pipeline.dll +0 -0
  137. mindspore/mint/__init__.py +131 -700
  138. mindspore/mint/distributed/__init__.py +5 -1
  139. mindspore/mint/distributed/distributed.py +194 -109
  140. mindspore/mint/linalg/__init__.py +2 -0
  141. mindspore/mint/nn/__init__.py +280 -18
  142. mindspore/mint/nn/functional.py +282 -64
  143. mindspore/mint/nn/layer/__init__.py +4 -0
  144. mindspore/mint/nn/layer/_functions.py +7 -3
  145. mindspore/mint/nn/layer/activation.py +120 -13
  146. mindspore/mint/nn/layer/conv.py +218 -24
  147. mindspore/mint/nn/layer/normalization.py +15 -16
  148. mindspore/mint/nn/layer/padding.py +1 -1
  149. mindspore/mint/nn/layer/pooling.py +66 -1
  150. mindspore/mint/optim/__init__.py +2 -1
  151. mindspore/mint/optim/sgd.py +171 -0
  152. mindspore/msobj140.dll +0 -0
  153. mindspore/mspdb140.dll +0 -0
  154. mindspore/mspdbcore.dll +0 -0
  155. mindspore/mspdbst.dll +0 -0
  156. mindspore/mspft140.dll +0 -0
  157. mindspore/msvcdis140.dll +0 -0
  158. mindspore/msvcp140_1.dll +0 -0
  159. mindspore/msvcp140_2.dll +0 -0
  160. mindspore/msvcp140_atomic_wait.dll +0 -0
  161. mindspore/msvcp140_codecvt_ids.dll +0 -0
  162. mindspore/nn/__init__.py +4 -1
  163. mindspore/nn/cell.py +1250 -176
  164. mindspore/nn/layer/activation.py +23 -21
  165. mindspore/nn/layer/basic.py +22 -16
  166. mindspore/nn/layer/container.py +1 -1
  167. mindspore/nn/layer/conv.py +22 -17
  168. mindspore/nn/layer/embedding.py +9 -8
  169. mindspore/nn/layer/normalization.py +48 -42
  170. mindspore/nn/layer/pooling.py +75 -31
  171. mindspore/nn/layer/transformer.py +11 -10
  172. mindspore/nn/learning_rate_schedule.py +4 -2
  173. mindspore/nn/loss/loss.py +27 -19
  174. mindspore/nn/optim/ada_grad.py +6 -5
  175. mindspore/nn/optim/adadelta.py +9 -7
  176. mindspore/nn/optim/adafactor.py +1 -1
  177. mindspore/nn/optim/adam.py +16 -12
  178. mindspore/nn/optim/adamax.py +8 -7
  179. mindspore/nn/optim/adasum.py +5 -5
  180. mindspore/nn/optim/asgd.py +1 -1
  181. mindspore/nn/optim/ftrl.py +11 -9
  182. mindspore/nn/optim/lamb.py +1 -1
  183. mindspore/nn/optim/lazyadam.py +12 -10
  184. mindspore/nn/optim/momentum.py +7 -6
  185. mindspore/nn/optim/optimizer.py +2 -2
  186. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  187. mindspore/nn/optim/rmsprop.py +13 -12
  188. mindspore/nn/optim/rprop.py +9 -7
  189. mindspore/nn/optim/sgd.py +9 -6
  190. mindspore/nn/optim/tft_wrapper.py +5 -2
  191. mindspore/nn/probability/bijector/bijector.py +17 -11
  192. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  193. mindspore/nn/probability/bijector/invert.py +2 -2
  194. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  195. mindspore/nn/probability/bijector/softplus.py +3 -2
  196. mindspore/nn/probability/distribution/beta.py +3 -3
  197. mindspore/nn/probability/distribution/categorical.py +1 -1
  198. mindspore/nn/probability/distribution/cauchy.py +4 -2
  199. mindspore/nn/probability/distribution/exponential.py +6 -7
  200. mindspore/nn/probability/distribution/gamma.py +2 -2
  201. mindspore/nn/probability/distribution/gumbel.py +2 -2
  202. mindspore/nn/probability/distribution/half_normal.py +5 -3
  203. mindspore/nn/probability/distribution/logistic.py +5 -3
  204. mindspore/nn/probability/distribution/poisson.py +1 -1
  205. mindspore/nn/probability/distribution/uniform.py +5 -3
  206. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  207. mindspore/nn/reinforcement/tensor_array.py +1 -1
  208. mindspore/nn/wrap/__init__.py +6 -6
  209. mindspore/nn/wrap/cell_wrapper.py +178 -117
  210. mindspore/nn/wrap/grad_reducer.py +45 -36
  211. mindspore/nn/wrap/loss_scale.py +3 -3
  212. mindspore/numpy/array_creations.py +3 -3
  213. mindspore/numpy/array_ops.py +1 -1
  214. mindspore/numpy/math_ops.py +4 -4
  215. mindspore/numpy/utils.py +1 -2
  216. mindspore/numpy/utils_const.py +1 -2
  217. mindspore/opencv_core452.dll +0 -0
  218. mindspore/opencv_imgcodecs452.dll +0 -0
  219. mindspore/opencv_imgproc452.dll +0 -0
  220. mindspore/ops/__init__.py +3 -2
  221. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  222. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  223. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  224. mindspore/ops/_register_for_op.py +0 -11
  225. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  226. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  227. mindspore/ops/_vmap/vmap_array_ops.py +7 -6
  228. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  229. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  230. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  231. mindspore/ops/auto_generate/__init__.py +4 -3
  232. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
  233. mindspore/ops/auto_generate/gen_extend_func.py +281 -135
  234. mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
  235. mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
  236. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  237. mindspore/ops/composite/__init__.py +2 -1
  238. mindspore/ops/composite/base.py +19 -24
  239. mindspore/ops/composite/math_ops.py +6 -16
  240. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  241. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
  242. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  243. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  244. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  248. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  249. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  250. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  251. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  252. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  254. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  255. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  256. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  257. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  259. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  260. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  263. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  264. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  267. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  268. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  271. mindspore/ops/function/__init__.py +28 -2
  272. mindspore/ops/function/_add_attr_func.py +58 -0
  273. mindspore/ops/function/array_func.py +1629 -2345
  274. mindspore/ops/function/clip_func.py +38 -45
  275. mindspore/ops/function/debug_func.py +36 -44
  276. mindspore/ops/function/grad/__init__.py +1 -0
  277. mindspore/ops/function/grad/grad_func.py +104 -71
  278. mindspore/ops/function/image_func.py +1 -1
  279. mindspore/ops/function/linalg_func.py +46 -78
  280. mindspore/ops/function/math_func.py +3035 -3705
  281. mindspore/ops/function/nn_func.py +676 -241
  282. mindspore/ops/function/other_func.py +159 -1
  283. mindspore/ops/function/parameter_func.py +17 -30
  284. mindspore/ops/function/random_func.py +204 -361
  285. mindspore/ops/function/reshard_func.py +4 -70
  286. mindspore/ops/function/sparse_func.py +3 -3
  287. mindspore/ops/function/sparse_unary_func.py +5 -5
  288. mindspore/ops/function/spectral_func.py +25 -58
  289. mindspore/ops/function/vmap_func.py +24 -17
  290. mindspore/ops/functional.py +6 -4
  291. mindspore/ops/functional_overload.py +547 -4
  292. mindspore/ops/op_info_register.py +32 -244
  293. mindspore/ops/operations/__init__.py +10 -5
  294. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  295. mindspore/ops/operations/_grad_ops.py +1 -10
  296. mindspore/ops/operations/_inner_ops.py +5 -76
  297. mindspore/ops/operations/_ms_kernel.py +4 -10
  298. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  299. mindspore/ops/operations/_scalar_ops.py +3 -2
  300. mindspore/ops/operations/_sequence_ops.py +1 -1
  301. mindspore/ops/operations/_tensor_array.py +1 -1
  302. mindspore/ops/operations/array_ops.py +37 -22
  303. mindspore/ops/operations/comm_ops.py +150 -107
  304. mindspore/ops/operations/custom_ops.py +221 -23
  305. mindspore/ops/operations/debug_ops.py +115 -16
  306. mindspore/ops/operations/inner_ops.py +1 -1
  307. mindspore/ops/operations/linalg_ops.py +1 -58
  308. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  309. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  310. mindspore/ops/operations/math_ops.py +21 -18
  311. mindspore/ops/operations/nn_ops.py +65 -191
  312. mindspore/ops/operations/other_ops.py +62 -9
  313. mindspore/ops/operations/random_ops.py +13 -7
  314. mindspore/ops/operations/reshard_ops.py +1 -1
  315. mindspore/ops/operations/sparse_ops.py +2 -2
  316. mindspore/ops/primitive.py +43 -32
  317. mindspore/ops/tensor_method.py +232 -13
  318. mindspore/ops_generate/__init__.py +0 -5
  319. mindspore/ops_generate/aclnn/__init__.py +0 -0
  320. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  321. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  322. mindspore/ops_generate/api/__init__.py +0 -0
  323. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  324. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  325. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  326. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  327. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  328. mindspore/ops_generate/api/gen_api.py +103 -0
  329. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  330. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  331. mindspore/ops_generate/common/__init__.py +0 -0
  332. mindspore/ops_generate/common/gen_constants.py +91 -0
  333. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  334. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  335. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  336. mindspore/ops_generate/gen_ops.py +23 -325
  337. mindspore/ops_generate/op_def/__init__.py +0 -0
  338. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  339. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  340. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
  341. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  342. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  343. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  344. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  345. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  346. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  347. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  348. mindspore/ops_generate/pyboost/__init__.py +0 -0
  349. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  350. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  351. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  352. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  353. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  354. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  355. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  356. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  357. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  358. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  359. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  360. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  361. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  362. mindspore/ops_generate/resources/__init__.py +0 -0
  363. mindspore/ops_generate/resources/resource_list.py +30 -0
  364. mindspore/ops_generate/resources/resource_loader.py +36 -0
  365. mindspore/ops_generate/resources/resource_manager.py +64 -0
  366. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  367. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  368. mindspore/parallel/__init__.py +6 -2
  369. mindspore/parallel/_auto_parallel_context.py +133 -6
  370. mindspore/parallel/_cell_wrapper.py +130 -15
  371. mindspore/parallel/_parallel_serialization.py +95 -4
  372. mindspore/parallel/_ps_context.py +1 -1
  373. mindspore/parallel/_recovery_context.py +7 -2
  374. mindspore/parallel/_tensor.py +142 -18
  375. mindspore/parallel/_utils.py +198 -25
  376. mindspore/parallel/algo_parameter_config.py +3 -3
  377. mindspore/parallel/auto_parallel.py +732 -0
  378. mindspore/parallel/checkpoint_convert.py +159 -0
  379. mindspore/parallel/checkpoint_transform.py +656 -37
  380. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  381. mindspore/parallel/cluster/run.py +1 -1
  382. mindspore/parallel/function/__init__.py +24 -0
  383. mindspore/parallel/function/reshard_func.py +259 -0
  384. mindspore/parallel/nn/__init__.py +25 -0
  385. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  386. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  387. mindspore/parallel/parameter_broadcast.py +24 -13
  388. mindspore/parallel/shard.py +137 -61
  389. mindspore/parallel/transform_safetensors.py +287 -95
  390. mindspore/pgodb140.dll +0 -0
  391. mindspore/pgort140.dll +0 -0
  392. mindspore/profiler/__init__.py +9 -5
  393. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  394. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  395. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
  397. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  398. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  399. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  400. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  401. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  402. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  403. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  404. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  405. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  406. mindspore/profiler/common/constant.py +12 -0
  407. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  408. mindspore/profiler/common/path_manager.py +24 -0
  409. mindspore/profiler/common/profiler_context.py +26 -2
  410. mindspore/profiler/common/profiler_meta_data.py +74 -0
  411. mindspore/profiler/common/profiler_parameters.py +59 -18
  412. mindspore/profiler/common/profiler_path_manager.py +66 -7
  413. mindspore/profiler/dynamic_profiler.py +112 -79
  414. mindspore/profiler/envprofiler.py +26 -1
  415. mindspore/profiler/experimental_config.py +197 -0
  416. mindspore/profiler/mstx.py +57 -14
  417. mindspore/profiler/platform/npu_profiler.py +33 -7
  418. mindspore/profiler/profiler.py +541 -45
  419. mindspore/profiler/profiler_action_controller.py +1 -1
  420. mindspore/profiler/profiler_interface.py +4 -0
  421. mindspore/profiler/schedule.py +57 -22
  422. mindspore/rewrite/api/node.py +15 -13
  423. mindspore/rewrite/api/symbol_tree.py +1 -1
  424. mindspore/run_check/_check_version.py +25 -14
  425. mindspore/run_check/run_check.py +1 -1
  426. mindspore/runtime/__init__.py +2 -2
  427. mindspore/runtime/executor.py +40 -11
  428. mindspore/runtime/memory.py +25 -8
  429. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  430. mindspore/swresample-4.dll +0 -0
  431. mindspore/swscale-6.dll +0 -0
  432. mindspore/tbbmalloc.dll +0 -0
  433. mindspore/tinyxml2.dll +0 -0
  434. mindspore/train/__init__.py +8 -8
  435. mindspore/train/_utils.py +35 -7
  436. mindspore/train/amp.py +1 -1
  437. mindspore/train/callback/__init__.py +2 -2
  438. mindspore/train/callback/_callback.py +2 -16
  439. mindspore/train/callback/_checkpoint.py +24 -40
  440. mindspore/train/callback/_cluster_monitor.py +14 -18
  441. mindspore/train/callback/_flops_collector.py +2 -3
  442. mindspore/train/callback/_history.py +7 -4
  443. mindspore/train/callback/_lambda_callback.py +2 -2
  444. mindspore/train/callback/_landscape.py +0 -3
  445. mindspore/train/callback/_loss_monitor.py +2 -1
  446. mindspore/train/callback/_on_request_exit.py +6 -5
  447. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  448. mindspore/train/callback/_summary_collector.py +8 -13
  449. mindspore/train/callback/_time_monitor.py +2 -1
  450. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
  451. mindspore/train/data_sink.py +25 -2
  452. mindspore/train/dataset_helper.py +4 -5
  453. mindspore/train/loss_scale_manager.py +8 -7
  454. mindspore/train/metrics/accuracy.py +3 -3
  455. mindspore/train/metrics/confusion_matrix.py +9 -9
  456. mindspore/train/metrics/error.py +3 -3
  457. mindspore/train/metrics/hausdorff_distance.py +4 -4
  458. mindspore/train/metrics/mean_surface_distance.py +3 -3
  459. mindspore/train/metrics/metric.py +0 -12
  460. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  461. mindspore/train/metrics/precision.py +8 -6
  462. mindspore/train/metrics/recall.py +9 -9
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  464. mindspore/train/mind_ir_pb2.py +19 -12
  465. mindspore/train/model.py +176 -103
  466. mindspore/train/serialization.py +246 -988
  467. mindspore/train/summary/_summary_adapter.py +2 -2
  468. mindspore/train/summary/summary_record.py +1 -1
  469. mindspore/turbojpeg.dll +0 -0
  470. mindspore/utils/__init__.py +3 -2
  471. mindspore/utils/dryrun.py +4 -2
  472. mindspore/utils/hooks.py +81 -0
  473. mindspore/utils/utils.py +138 -4
  474. mindspore/vcmeta.dll +0 -0
  475. mindspore/vcruntime140.dll +0 -0
  476. mindspore/vcruntime140_1.dll +0 -0
  477. mindspore/version.py +1 -1
  478. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
  479. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
  480. mindspore/_install_custom.py +0 -43
  481. mindspore/common/_register_for_adapter.py +0 -74
  482. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  483. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  484. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  485. mindspore/ops_generate/gen_constants.py +0 -190
  486. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  487. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  488. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  489. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  490. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -17,6 +17,7 @@
17
17
  """Providing interface methods."""
18
18
  from __future__ import absolute_import
19
19
 
20
+ import gc
20
21
  import types
21
22
  import sys
22
23
  import os
@@ -24,11 +25,11 @@ import time
24
25
  import ast
25
26
  import inspect
26
27
  import importlib
27
- import hashlib
28
28
  import contextlib
29
+ import json
29
30
  from collections import OrderedDict, namedtuple
30
31
  from functools import wraps
31
- import numpy as np
32
+ from typing import Optional, Callable
32
33
  import mindspore as ms
33
34
  from mindspore import context
34
35
  from mindspore import log as logger
@@ -39,21 +40,23 @@ from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
39
40
  from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
40
41
  from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
41
42
  from mindspore._c_expression.amp import get_curr_amp_strategy
42
- from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
43
+ from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
43
44
  PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
44
- _ms_memory_recycle, _bind_device_ctx, StubNode
45
+ _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, StubNode, MSContext, TensorPy as Tensor
45
46
  from mindspore.parallel._ps_context import _is_role_sched
46
- from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
47
- _is_in_auto_parallel_mode, _is_parallel_mode
47
+ from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
48
+ _is_parallel_mode
48
49
  from mindspore import _checkparam as Validator
49
50
  from mindspore._checkparam import is_stub_tensor
50
51
  from mindspore.common._utils import is_shape_unknown
51
52
  from mindspore.common.mutable import mutable, _check_element_type
52
- from mindspore.common._register_for_adapter import ms_adapter_registry
53
53
  from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
54
54
  get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
55
55
  from mindspore.common._pijit_context import PIJitCaptureContext
56
56
  from mindspore.common.parameter import Parameter, set_parameter_hook_updated, parameter_hook_updated
57
+ from mindspore.common.jit_context import jit_context
58
+ from mindspore.common.jit_trace import _jit_trace
59
+ from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
57
60
 
58
61
  # Store ms_function class compiled pipeline cache.
59
62
  ms_compile_cache = set()
@@ -107,8 +110,7 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
107
110
  logger.info(f"The {echo_function_name} has been compiled again. "
108
111
  f"{tips} ")
109
112
  else:
110
- tips = "Try to decorate the function with @jit(hash_args=...) " \
111
- "or @jit(compile_once=True) to reduce the compile time. " \
113
+ tips = "Try to reuse the function object decorated by @jit to reduce the compile time. " \
112
114
  "For more details, get instructions about `jit` at " \
113
115
  "https://www.mindspore.cn/search?inputValue=jit."
114
116
  logger.warning(f"The {echo_function_name} has been compiled again. "
@@ -120,14 +122,6 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
120
122
  function_phases[full_function_name].add(create_time)
121
123
 
122
124
 
123
- def _ms_adapter_tensor_as_parameter_output(data):
124
- """Check whether the data is an output from a parameter which is a ms_adapter tensor.
125
- Pylint: disable=unidiomatic-typecheck.
126
- """
127
- return ms_adapter_registry.is_registered and isinstance(data, ms_adapter_registry.tensor) \
128
- and hasattr(data, "__ms_parameter_output__") and getattr(data, "__ms_parameter_output__")
129
-
130
-
131
125
  def _convert_python_data(data):
132
126
  """
133
127
  Convert C++ data to python.
@@ -138,18 +132,8 @@ def _convert_python_data(data):
138
132
  Returns:
139
133
  data, a data convert C++ to python
140
134
  """
141
- if isinstance(data, (Tensor, PythonTensor)) and data.adapter_flag:
142
- return ms_adapter_registry.tensor(data)
143
- if _ms_adapter_tensor_as_parameter_output(data) and hasattr(data, "tensor"):
144
- return data.tensor
145
- if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
146
- return PythonTensor(data, internal=True)
147
- if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
148
- return PythonCSRTensor(csr_tensor=data)
149
- if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
150
- return PythonCOOTensor(coo_tensor=data)
151
- if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
152
- return PythonRowTensor(row_tensor=data)
135
+ if isinstance(data, PythonTensor):
136
+ return data
153
137
  if isinstance(data, StubNode):
154
138
  return ms.common._stub_tensor._convert_stub(data)
155
139
  if data.__class__ is tuple:
@@ -160,6 +144,12 @@ def _convert_python_data(data):
160
144
  fields = data_dict.keys()
161
145
  return namedtuple(type_name, fields)(**_convert_python_data(data_dict))
162
146
  return tuple(_convert_python_data(x) for x in data)
147
+ if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
148
+ return PythonCSRTensor(csr_tensor=data)
149
+ if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
150
+ return PythonCOOTensor(coo_tensor=data)
151
+ if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
152
+ return PythonRowTensor(row_tensor=data)
163
153
  if data.__class__ is list:
164
154
  # Keep list object not change for inplace operation.
165
155
  for i in range(len(data)):
@@ -578,11 +568,11 @@ def _get_hook_key(*args, **kwargs):
578
568
  return hook_key
579
569
 
580
570
 
581
- class _MindsporeFunctionExecutor:
571
+ class _JitExecutor:
582
572
  """
583
573
  Represents a function compiled by graph compiler.
584
574
 
585
- _MindsporeFunctionExecutor will compile the original function for every combination
575
+ _JitExecutor will compile the original function for every combination
586
576
  of argument types and shapes it is given (as well as their values, optionally).
587
577
 
588
578
  Args:
@@ -596,7 +586,7 @@ class _MindsporeFunctionExecutor:
596
586
  The result of pipeline running in graph mode.
597
587
  """
598
588
 
599
- def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None):
589
+ def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0):
600
590
  init_pipeline()
601
591
  if not isinstance(fn, (types.FunctionType, types.MethodType)):
602
592
  raise RuntimeError('fn {} is not function or method'.format(fn))
@@ -608,13 +598,19 @@ class _MindsporeFunctionExecutor:
608
598
  self.obj = obj
609
599
  self.shard_parent_obj = obj
610
600
  self.enable_tuple_broaden = False
611
- self._graph_executor = GraphExecutor_.get_instance()
601
+ if _run_jit_pipeline():
602
+ self._graph_executor = JitExecutor_.get_instance()
603
+ else:
604
+ self._graph_executor = GraphExecutor_.get_instance()
612
605
  self._create_time = ms_create_time
613
606
  self._compile_args = None
607
+ self._enable_auto_dynamic = dynamic == 1
614
608
  self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
615
609
 
616
610
  @_wrap_func
617
611
  def __call__(self, *args, **kwargs):
612
+ if jit_context() and jit_context().is_nested():
613
+ return jit_context().run_graph("", None, *())
618
614
  args_list = args
619
615
  if self.obj is not None:
620
616
  args_list = args_list[1:]
@@ -634,10 +630,14 @@ class _MindsporeFunctionExecutor:
634
630
  return None
635
631
 
636
632
  new_inputs = self._generate_run_args(args_list, kwargs)
637
- if context.get_context("mode") == context.PYNATIVE_MODE:
633
+ if context.get_context("mode") == context.PYNATIVE_MODE and not jit_context():
638
634
  output = _pynative_executor.grad_jit(*new_inputs)
639
635
  else:
640
636
  output = self._graph_executor(tuple(new_inputs), phase)
637
+ if jit_context():
638
+ if is_stub_tensor(output):
639
+ output = output.stub_sync()
640
+ return jit_context().run_graph(phase, output, *tuple(new_inputs))
641
641
 
642
642
  return output
643
643
 
@@ -653,7 +653,8 @@ class _MindsporeFunctionExecutor:
653
653
  compile_args = self._generate_compile_args(args)
654
654
  key_id = self._get_key_id()
655
655
  compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
656
- self.input_signature)
656
+ self.input_signature,
657
+ self._enable_auto_dynamic)
657
658
 
658
659
  # Add mutable for compile_args for two scene:
659
660
  # 1) Origin args is mutable.
@@ -704,7 +705,7 @@ class _MindsporeFunctionExecutor:
704
705
 
705
706
  update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
706
707
 
707
- if phase in ms_compile_cache and not parameter_hook_updated():
708
+ if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
708
709
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
709
710
  # generated in generate_arguments_key.
710
711
  self._graph_executor.clear_compile_arguments_resource()
@@ -726,7 +727,7 @@ class _MindsporeFunctionExecutor:
726
727
  setattr(self.fn.__func__, "__jit_function__", True)
727
728
  else:
728
729
  setattr(self.fn, "__jit_function__", True)
729
- is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
730
+ is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase)
730
731
  if isinstance(self.fn, types.MethodType):
731
732
  delattr(self.fn.__func__, "__jit_function__")
732
733
  else:
@@ -734,7 +735,7 @@ class _MindsporeFunctionExecutor:
734
735
  else:
735
736
  if isinstance(self.obj, ms.nn.Cell):
736
737
  self._graph_executor.set_weights_values(self.obj.parameters_dict())
737
- is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase, True)
738
+ is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase)
738
739
 
739
740
  if not is_compile:
740
741
  raise RuntimeError("Executor compile failed.")
@@ -760,7 +761,7 @@ class _MindsporeFunctionExecutor:
760
761
  else:
761
762
  key_id = str(id(self.obj)) + str(self._create_time)
762
763
 
763
- if _pynative_executor.grad_flag():
764
+ if _pynative_executor.requires_grad():
764
765
  key_id = key_id + ".grad"
765
766
  return key_id
766
767
 
@@ -770,9 +771,9 @@ class _MindsporeFunctionExecutor:
770
771
  self.fn.__code__.co_firstlineno)
771
772
  echo_function_name = "function \"" + self.fn.__name__ + "\" at the file \"" + self.fn.__code__.co_filename \
772
773
  + "\", line " + str(self.fn.__code__.co_firstlineno)
773
- if _pynative_executor.grad_flag():
774
+ if _pynative_executor.requires_grad():
774
775
  generate_name = generate_name + ".grad"
775
- if _is_pynative_parallel():
776
+ if self.fn.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
776
777
  generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
777
778
  return generate_name, echo_function_name
778
779
 
@@ -833,6 +834,14 @@ class _MindsporeFunctionExecutor:
833
834
  """
834
835
  return _get_args_for_run(self, args_list, kwargs, self._compile_args)
835
836
 
837
+ def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
838
+ """Get graph proto from pipeline."""
839
+ if use_prefix:
840
+ exec_id = exec_id + '.' + obj.arguments_key
841
+ if self._graph_executor.has_compiled(exec_id) is False:
842
+ return None
843
+ return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
844
+
836
845
 
837
846
  # The attributes used to identify a given object.
838
847
  attr_op = {"__str__": lambda x: x.__str__(),
@@ -845,6 +854,13 @@ attr_op = {"__str__": lambda x: x.__str__(),
845
854
  }
846
855
 
847
856
 
857
+ def _is_inner_func(func):
858
+ """Check whether the func is an inner func which needs hash_args parameter."""
859
+ # This is a workaround for inner api, should fix it later.
860
+ inner_func = ["after_shard", "_wrap_container"]
861
+ return func.__name__ in inner_func
862
+
863
+
848
864
  def _get_obj_id(input_obj):
849
865
  """Get hash id of single object."""
850
866
  obj_id = ".".join(
@@ -859,50 +875,227 @@ def _get_jit_hash(hash_input):
859
875
  return _get_obj_id(hash_input)
860
876
 
861
877
 
862
- def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=None, compile_once=False):
878
+ def _get_hash_obj(options):
879
+ hash_obj = None
880
+ if "hash_args" in options:
881
+ hash_obj = _get_jit_hash(options["hash_args"])
882
+ del options["hash_args"]
883
+ return hash_obj
884
+
885
+
886
+ def _check_option_device(option, device):
887
+ """Check jit options wiwh device"""
888
+ option_device_cfgs = {
889
+ 'disable_format_transform': ['GPU'],
890
+ 'exec_order': ['Ascend'],
891
+ 'ge_options': ['Ascend'],
892
+ 'infer_boost': ['Ascend'],
893
+ }
894
+ if option in option_device_cfgs and device not in option_device_cfgs[option]:
895
+ logger.warning(f"For 'jit(options)', the option '{option}' is only support device in "
896
+ f"'{option_device_cfgs[option]}', but got '{device}', ignore it.")
897
+
898
+
899
+ def _check_option_backend(option, backend):
900
+ """Check jit options wiwh backend"""
901
+ option_backend_cfgs = {
902
+ 'disable_format_transform': ['ms_backend'],
903
+ 'exec_order': ['ms_backend'],
904
+ 'ge_options': ['GE'],
905
+ 'infer_boost': ['ms_backend'],
906
+ }
907
+ if option in option_backend_cfgs and backend not in option_backend_cfgs[option]:
908
+ logger.warning(f"For 'jit(options)', the option '{option}' is only support backend in "
909
+ f"'{option_backend_cfgs[option]}', but got '{backend}', ignore it.")
910
+
911
+
912
+ def _check_disable_format_transform_value(option, disable_format_transform):
913
+ """check disable_format_transform option value"""
914
+ if not isinstance(disable_format_transform, bool):
915
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be bool, "
916
+ f"but got {type(disable_format_transform)}.")
917
+
918
+
919
+ def _check_exec_order_value(option, exec_order):
920
+ """check exec_order option value"""
921
+ if not isinstance(exec_order, str):
922
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(exec_order)}.")
923
+
924
+ if exec_order not in ['bfs', 'dfs']:
925
+ raise ValueError(f"For '{option}', the value of '{option}' must be one of "
926
+ f"['bfs', 'dfs'], but got '{exec_order}'.")
927
+
928
+
929
+ def _check_ge_options_value(option, ge_options):
930
+ """check ge_options option value"""
931
+ if not isinstance(ge_options, dict):
932
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be dict, but got {type(ge_options)}.")
933
+
934
+ for level, options in ge_options.items():
935
+ if level not in ['global', 'session']:
936
+ raise ValueError(f"For '{option}', the key of '{option}' must be one of "
937
+ f"['global', 'session'], but got '{level}'.")
938
+
939
+ if not isinstance(options, dict):
940
+ raise TypeError(f"For '{option}', the type of {level} options must be dict, "
941
+ f"but got {type(options)}. The error options: {options}.")
942
+
943
+ for key, value in options.items():
944
+ if not isinstance(key, str):
945
+ raise TypeError(f"For '{option}', the type of key and value must be str, "
946
+ f"but got {type(key)}. The error key is {key}.")
947
+ if not isinstance(value, str):
948
+ raise TypeError(f"For '{option}', the type of key and value must be str, "
949
+ f"but got {type(value)}. The error value is {value}")
950
+
951
+
952
+ def _check_infer_boost_value(option, value):
953
+ """check infer_boost option value"""
954
+ if not isinstance(value, str):
955
+ raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(value)}.")
956
+
957
+ if value not in ['on', 'off']:
958
+ raise ValueError(f"For '{option}', the value of '{option}' must be one of ['on', 'off'], but got '{value}'.")
959
+
960
+
961
+ def _check_option_value(option, value):
962
+ """check jit options wiwh value"""
963
+ option_valuecheck_funcs = {
964
+ 'disable_format_transform': _check_disable_format_transform_value,
965
+ 'exec_order': _check_exec_order_value,
966
+ 'ge_options': _check_ge_options_value,
967
+ 'infer_boost': _check_infer_boost_value,
968
+ }
969
+ if option in option_valuecheck_funcs:
970
+ option_valuecheck_funcs[option](option, value)
971
+ else:
972
+ logger.warning(f"For 'jit(options)', the option argument '{option}' is not recognized, please check!"
973
+ f"For detailed usage of 'jit(options)', please refer to the Mindspore official website.")
974
+
975
+
976
+ def _check_options(options, backend):
977
+ """Check jit options"""
978
+ # check whether there are deprecated parameters in the dict `options`.
979
+ deprecated_args = {'mode': 'capture_mode', 'input_signature': 'dynamic', 'hash_args: ': '',
980
+ 'jit_config': 'jit_level, fullgraph or options', 'compile_once': ''}
981
+ for key, value in deprecated_args.items():
982
+ if key in options:
983
+ log = f"For 'jit', the parameter '{key}' has been deprecated."
984
+ if value != '':
985
+ log += f" Please use the parameter '{value}' instead. For more details, please refer to " \
986
+ f"https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html."
987
+ logger.warning(log)
988
+ del options[key]
989
+
990
+ # check options' device, backend and value
991
+ for option, value in options.items():
992
+ _check_option_backend(option, backend)
993
+ _check_option_value(option, value)
994
+
995
+
996
+ def jit(
997
+ function: Optional[Callable] = None,
998
+ *,
999
+ capture_mode: str = "ast",
1000
+ jit_level: str = "O0",
1001
+ dynamic: int = 0,
1002
+ fullgraph: bool = False,
1003
+ backend: str = "",
1004
+ **options):
863
1005
  """
864
1006
  Create a callable MindSpore graph from a Python function.
865
1007
 
866
1008
  This allows the MindSpore runtime to apply optimizations based on graph.
867
1009
 
868
1010
  Note:
869
- - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
870
- will not accept `**kwargs`.
871
- - It is not supported to run a function with decoration @jit(mode=“PIJit”)
872
- in static graph mode, in which case the decoration @jit(mode=“PIJit”) is considered invalid.
873
- - Calls to functions with decorated @jit(mode=“PIJit”) inside functions
874
- decorated with @jit(mode=“PIJit”) are not supported,
875
- and the decoration @jit(mode=“PIJit”) is considered invalid.
1011
+ - It is not supported to run a function with decoration @jit(capture_mode=“bytecode”)
1012
+ in static graph mode, in which case the decoration @jit(capture_mode=“bytecode”) is considered invalid.
1013
+ - Calls to functions with decorated @jit(capture_mode=“bytecode”) inside functions
1014
+ decorated with @jit(capture_mode=“ast”) are not supported,
1015
+ and the decoration @jit(capture_mode=“bytecode”) is considered invalid.
876
1016
 
877
1017
  Args:
878
- fn (Function): The Python function that will be run as a graph. Default: ``None`` .
879
- mode (str): The type of jit used, the value of mode should be ``PIJit`` or ``PSJit``. Default: ``PSJit`` .
880
-
881
- - `PSJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html>`_ :
882
- Parse python ast to build graph.
883
- - `PIJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/pynative.html#pijit>`_ :
884
- Parse python bytecode to build graph at runtime.
885
-
886
- input_signature (Union[Tuple, List, Dict, Tensor]): The Tensor which describes the input arguments. The
887
- shape and dtype of the Tensor will be supplied to this function. If `input_signature` is specified, the
888
- input parameters of `fn` cannot accept `**kwargs`, and the shape and dtype of actual inputs should keep the
889
- same as `input_signature`. Otherwise, TypeError will be raised. There are two mode for `input_signature`:
890
-
891
- - Full mode: Arguments is a Tuple, List or a Tensor, and they will be used as all compile inputs
892
- for graph-compiling.
893
- - Incremental mode: Argument is a Dict, and they will set to some of the graph inputs, which will be
894
- substituted into the input at the corresponding position for graph-compiling.
895
-
896
- Default: ``None`` .
897
-
898
- hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
899
- like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
900
- will trigger recompilation. Default: ``None`` .
901
- jit_config (JitConfig): Jit config for compile. Default: ``None`` .
902
- compile_once(bool): ``True``: The function would be compiled once when it was created many times.
903
- But it may be wrong if the free variables were changed. ``False`` : It would be recompiled when
904
- it was created again.
905
- Default: ``False`` .
1018
+ function (Function, optional): The Python function that will be run as a graph. Default: ``None``.
1019
+
1020
+ Keyword Args:
1021
+ capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
1022
+ should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
1023
+
1024
+ - `ast <https://www.mindspore.cn/tutorials/en/master/compile/static_graph.html>`_ :
1025
+ Parse Python ast to build graph.
1026
+ - `bytecode` :
1027
+ Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
1028
+ change and/or deletion.
1029
+ - `trace` : Trace the execution of Python code to build graph. This is an experimental prototype that is
1030
+ subject to change and/or deletion.
1031
+
1032
+ jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
1033
+ with default backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
1034
+
1035
+ - `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
1036
+ - `O1`: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
1037
+ level is experimental and is being improved.
1038
+
1039
+ dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
1040
+ is as follows:
1041
+
1042
+ - `0`: Do not perform dynamic shape compilation.
1043
+ - `1`: Enable dynamic shape compilation and automatically detect shape changes.
1044
+
1045
+ fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
1046
+ be compatible with all Python syntax in the function as much as possible. If True, we require that the
1047
+ entire function can be captured into graph. If this is not possible (that is, if there is Python syntax
1048
+ not supported), then it will raise an exception. This currently only applies when capture_mode is ast.
1049
+ Default: ``False``.
1050
+ backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
1051
+ use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including Atlas
1052
+ A2 training series products by default.
1053
+
1054
+ - `ms_backend`: Adopt KernelByKernel execution mode.
1055
+ - `GE`: Adopt Sink execution mode. The whole model will be sinked to device to execute, only applicable to
1056
+ the top cell of model. And only can be used in Ascend platform.
1057
+
1058
+ **options (dict): A dictionary of options to pass to the compilation backend.
1059
+
1060
+ Some options are device specific, see the below table for details:
1061
+
1062
+ +---------------------------+---------------------------+-------------------------+
1063
+ | Option Parameters | Hardware Platform Support | Backend Support |
1064
+ +===========================+===========================+=========================+
1065
+ | disable_format_transform | GPU | ms_backend |
1066
+ +---------------------------+---------------------------+-------------------------+
1067
+ | exec_order | Ascend | ms_backend |
1068
+ +---------------------------+---------------------------+-------------------------+
1069
+ | ge_options | Ascend | GE |
1070
+ +---------------------------+---------------------------+-------------------------+
1071
+ | infer_boost | Ascend | ms_backend |
1072
+ +---------------------------+---------------------------+-------------------------+
1073
+
1074
+ - disable_format_transform (bool, optional): Whether to disable the automatic format transform function
1075
+ from NCHW to NHWC. When the network training performance of fp16 is worse than fp32,
1076
+ `disable_format_transform` can be set to ``True`` to try to improve training performance.
1077
+ Default: ``False`` .
1078
+ - exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
1079
+ methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
1080
+
1081
+ - `bfs`: The default sorting method, breadth priority, good communication masking, relatively good
1082
+ performance.
1083
+ - `dfs`: An optional sorting method, depth-first sorting. The performance is relatively worse than that
1084
+ of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
1085
+ other execution orders run out of memory (OOM).
1086
+
1087
+ - ge_options (dict): Set options for ge backend. The options are divided into two categories: global,
1088
+ and session. This is an experimental prototype that is subject to change and/or deletion.
1089
+ For detailed information, please refer to `Ascend community <https://www.hiascend.com/document/detail/zh/canncommercial/80RC3/apiref/ascendgraphapi/atlasgeapi_07_0146.html>`_ .
1090
+
1091
+ - global (dict): Set global options.
1092
+ - session (dict): Set session options.
1093
+
1094
+ - infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
1095
+ the inference mode is disabled. The range is as follows:
1096
+
1097
+ - `on`: Enable inference mode, get better infer performance.
1098
+ - `off`: Disable inference mode, use forward for inference. The performance is poor.
906
1099
 
907
1100
  Returns:
908
1101
  Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
@@ -921,12 +1114,12 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
921
1114
  >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
922
1115
  >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
923
1116
  ...
924
- >>> # create a callable MindSpore graph by calling decorator @jit
1117
+ >>> # create a callable MindSpore graph by calling jit
925
1118
  >>> def tensor_add(x, y):
926
1119
  ... z = x + y
927
1120
  ... return z
928
1121
  ...
929
- >>> tensor_add_graph = jit(fn=tensor_add)
1122
+ >>> tensor_add_graph = jit(function=tensor_add)
930
1123
  >>> out = tensor_add_graph(x, y)
931
1124
  ...
932
1125
  >>> # create a callable MindSpore graph through decorator @jit
@@ -937,180 +1130,70 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
937
1130
  ...
938
1131
  >>> out = tensor_add_with_dec(x, y)
939
1132
  ...
940
- >>> # create a callable MindSpore graph through decorator @jit with input_signature parameter
941
- >>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
942
- ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
943
- ... def tensor_add_with_sig(x, y):
944
- ... z = x + y
945
- ... return z
946
- ...
947
- >>> out = tensor_add_with_sig(x, y)
948
- ...
949
- >>> @jit(input_signature={"y": Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))})
950
- ... def tensor_add_with_sig_1(x, y):
1133
+ >>> # create a callable MindSpore graph and capture the entire function into the graph
1134
+ >>> @jit(fullgraph=True)
1135
+ ... def tensor_add_fullgraph(x, y):
951
1136
  ... z = x + y
952
1137
  ... return z
953
1138
  ...
954
- >>> out1 = tensor_add_with_sig_1(x, y)
955
- ...
956
- ... # Set hash_args as fn, otherwise cache of compiled closure_fn will not be reused.
957
- ... # While fn differs during calling again, recompilation will be triggered.
958
- >>> def func(x):
959
- ... return ops.exp(x)
960
- ...
961
- >>> def closure_fn(x, fn):
962
- ... @jit(hash_args=fn)
963
- ... def inner_fn(a):
964
- ... return fn(a)
965
- ... return inner_fn(x)
966
- ...
967
- >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
968
- >>> for i in range(10):
969
- ... closure_fn(inputs, func)
970
- ...
971
- ... # Set compile_once = True, otherwise the train_step will be compiled again.
972
- >>> def train(x):
973
- ... @jit(compile_once = True)
974
- ... def train_step(x):
975
- ... return ops.exp(x)
976
- ... for i in range(10):
977
- ... train_step(x)
978
- ...
979
- >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
980
- >>> for i in range(10):
981
- ... train(inputs)
1139
+ >>> out = tensor_add_fullgraph(x, y)
982
1140
  """
983
1141
 
984
- def wrap_mindspore(func):
985
- if not isinstance(compile_once, bool):
986
- logger.warning(f"The parameter `compile_once` of jit should be a bool, "
987
- f"but got {type(compile_once)}.")
988
- if hash_args:
989
- hash_obj = _get_jit_hash(hash_args)
990
- elif compile_once:
991
- hash_obj = 0
992
- else:
1142
+ capture_mode = Validator.check_string(capture_mode, ["ast", "bytecode", "trace"], "capture_mode", "jit")
1143
+ jit_level = Validator.check_string(jit_level, ["O0", "O1"], "jit_level", "jit")
1144
+ dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
1145
+ fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
1146
+ if backend == "":
1147
+ backend = "GE" if MSContext.get_instance().get_ascend_soc_version() == "ascend910" else "ms_backend"
1148
+ backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
1149
+ jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
1150
+ hash_obj = _get_hash_obj(options)
1151
+ _check_options(options, backend)
1152
+ options_str = json.dumps(options)
1153
+ infer_boost = options['infer_boost'] if 'infer_boost' in options else "off"
1154
+ exc_mode = options['exc_mode'] if 'exc_mode' in options else "auto"
1155
+ jit_config = JitConfig(jit_level=jit_level, exc_mode=exc_mode, jit_syntax_level=jit_syntax_level,
1156
+ infer_boost=infer_boost, backend=backend, options=options_str)
1157
+
1158
+ def wrap_func(func):
1159
+ nonlocal hash_obj
1160
+ if hash_obj is None or not _is_inner_func(func):
993
1161
  hash_obj = int(time.time() * 1e9)
994
1162
 
995
- dyn_args = _process_dyn_args(func, input_signature)
996
-
997
1163
  @wraps(func)
998
1164
  def staging_specialize(*args, **kwargs):
999
1165
  if os.getenv("MS_JIT") == '0':
1000
1166
  return func(*args, **kwargs)
1001
1167
 
1002
1168
  args, kwargs = _handle_func_args(func, *args, **kwargs)
1003
-
1004
1169
  process_obj = None
1005
1170
  if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
1006
1171
  process_obj = args[0]
1007
- # only the function or cell instance wrapped by shard will fall into this branch
1008
- if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
1009
- process_obj = hash_args
1010
1172
  # Handle auto mixed precision strategy.
1011
1173
  if not hasattr(func, "amp_strategy"):
1012
1174
  if isinstance(func, types.MethodType):
1013
1175
  setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
1014
1176
  else:
1015
1177
  setattr(func, "amp_strategy", get_curr_amp_strategy())
1016
- out = _MindsporeFunctionExecutor(func, hash_obj, dyn_args, process_obj, jit_config)(*args, **kwargs)
1178
+
1179
+ ms_function_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic)
1180
+ out = ms_function_executor(*args, **kwargs)
1017
1181
  return out
1018
1182
 
1019
1183
  return staging_specialize
1020
1184
 
1021
- wrap_func = wrap_mindspore
1022
- if mode == "PIJit":
1023
- wrap_func = PIJitCaptureContext(jit_config, input_signature)
1185
+ if capture_mode == "bytecode":
1186
+ wrap_func = PIJitCaptureContext(jit_config)
1187
+ elif capture_mode == "trace":
1188
+ if function is not None:
1189
+ return _jit_trace(function)
1190
+ return _jit_trace
1024
1191
 
1025
- if fn is not None:
1026
- return wrap_func(fn)
1192
+ if function is not None:
1193
+ return wrap_func(function)
1027
1194
  return wrap_func
1028
1195
 
1029
1196
 
1030
- def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
1031
- """
1032
- Create a callable MindSpore graph from a Python function.
1033
-
1034
- This allows the MindSpore runtime to apply optimizations based on graph.
1035
-
1036
- Note:
1037
- - `ms_function` will be deprecated and removed in a future version. Please use :func:`mindspore.jit` instead.
1038
- - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
1039
- will not accept `**kwargs`.
1040
-
1041
- Args:
1042
- fn (Function): The Python function that will be run as a graph. Default: ``None`` .
1043
- input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
1044
- will be supplied to this function. The shape and dtype of actual inputs of `fn` should
1045
- keep the same as input_signature. Otherwise, TypeError will be raised. Default: ``None`` .
1046
- hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
1047
- like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
1048
- will trigger recompilation. Default: ``None`` .
1049
- jit_config (JitConfig): Jit config for compile. Default: ``None`` .
1050
-
1051
- Returns:
1052
- Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
1053
- None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
1054
- equal to the case when `fn` is not None.
1055
-
1056
- Supported Platforms:
1057
- ``Ascend`` ``GPU`` ``CPU``
1058
-
1059
- Examples:
1060
- >>> import numpy as np
1061
- >>> from mindspore import Tensor
1062
- >>> from mindspore import ops
1063
- >>> from mindspore import ms_function
1064
- ...
1065
- >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
1066
- >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
1067
- ...
1068
- >>> # create a callable MindSpore graph by calling ms_function
1069
- >>> def tensor_add(x, y):
1070
- ... z = x + y
1071
- ... return z
1072
- ...
1073
- >>> tensor_add_graph = ms_function(fn=tensor_add)
1074
- >>> out = tensor_add_graph(x, y)
1075
- ...
1076
- >>> # create a callable MindSpore graph through decorator @ms_function
1077
- >>> @ms_function
1078
- ... def tensor_add_with_dec(x, y):
1079
- ... z = x + y
1080
- ... return z
1081
- ...
1082
- >>> out = tensor_add_with_dec(x, y)
1083
- ...
1084
- >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
1085
- >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
1086
- ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
1087
- ... def tensor_add_with_sig(x, y):
1088
- ... z = x + y
1089
- ... return z
1090
- ...
1091
- >>> out = tensor_add_with_sig(x, y)
1092
- ...
1093
- ... # Set hash_args as fn, otherwise cache of compiled `closure_fn` will not be reused.
1094
- ... # While fn differs during calling again, recompilation will be triggered.
1095
- >>> def func(x):
1096
- ... return ops.exp(x)
1097
- ...
1098
- >>> def closure_fn(x, fn):
1099
- ... @ms_function(hash_args=fn)
1100
- ... def inner_fn(a):
1101
- ... return fn(a)
1102
- ... return inner_fn(x)
1103
- ...
1104
- >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
1105
- >>> for i in range(10):
1106
- ... closure_fn(inputs, func)
1107
- """
1108
-
1109
- logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
1110
- "Please use 'mindspore.jit' instead.")
1111
- return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)
1112
-
1113
-
1114
1197
  def _core(fn=None, **flags):
1115
1198
  """
1116
1199
  A decorator that adds a flag to the function.
@@ -1203,69 +1286,6 @@ def _no_recursive(callable_obj):
1203
1286
  return callable_obj
1204
1287
 
1205
1288
 
1206
- def ms_class(cls):
1207
- """
1208
- Class decorator for user-defined classes.
1209
-
1210
- This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
1211
-
1212
- Note:
1213
- `ms_class` will be deprecated and removed in a future version. Please use :func:`mindspore.jit_class` instead.
1214
-
1215
- Args:
1216
- cls (Class): User-defined class.
1217
-
1218
- Returns:
1219
- Class.
1220
-
1221
- Raises:
1222
- TypeError: If ms_class is used for non-class types or nn.Cell.
1223
- AttributeError: If the private attributes or magic methods of the class decorated with ms_class is called.
1224
-
1225
- Supported Platforms:
1226
- ``Ascend`` ``GPU`` ``CPU``
1227
-
1228
- Examples:
1229
- >>> import mindspore.nn as nn
1230
- >>> from mindspore import ms_class
1231
- ...
1232
- >>> @ms_class
1233
- ... class UserDefinedNet:
1234
- ... def __init__(self):
1235
- ... self.value = 10
1236
- ...
1237
- ... def func(self, x):
1238
- ... return 2 * x
1239
- ...
1240
- >>> class Net(nn.Cell):
1241
- ... def __init__(self):
1242
- ... super(Net, self).__init__()
1243
- ... self.net = UserDefinedNet()
1244
- ...
1245
- ... def construct(self, x):
1246
- ... out = self.net.value + self.net.func(x)
1247
- ... return out
1248
- ...
1249
- >>> net = Net()
1250
- >>> out = net(5)
1251
- >>> print(out)
1252
- 20
1253
- """
1254
-
1255
- logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
1256
- "Please use 'mindspore.jit_class' instead.")
1257
-
1258
- # Check if cls is of type class.
1259
- if not inspect.isclass(cls):
1260
- raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
1261
- # Check if cls is nn.Cell.
1262
- if issubclass(cls, ms.nn.Cell):
1263
- raise TypeError(f"Decorator ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
1264
- logger.info(f'Found ms_class: {cls}.')
1265
- setattr(cls, '__ms_class__', True)
1266
- return cls
1267
-
1268
-
1269
1289
  def jit_class(cls):
1270
1290
  """
1271
1291
  Class decorator for user-defined classes.
@@ -1322,28 +1342,6 @@ def jit_class(cls):
1322
1342
  return cls
1323
1343
 
1324
1344
 
1325
- def set_adapter_config(config):
1326
- """
1327
- Register configuration information for MSAdapter.
1328
-
1329
- Args:
1330
- config (dict): Configuration information.
1331
- """
1332
- if not isinstance(config, dict):
1333
- raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
1334
- for key, value in config.items():
1335
- if key == "Tensor":
1336
- ms_adapter_registry.register_tensor(value)
1337
- elif key == "Parameter":
1338
- ms_adapter_registry.register_parameter(value)
1339
- elif key == "convert_object_map":
1340
- ms_adapter_registry.register_convert_map(value)
1341
- elif key == "convert_adapter_tensor_map":
1342
- ms_adapter_registry.register_convert_adapter_tensor_map(value)
1343
- else:
1344
- raise ValueError(f"Unsupported key in adapter config: {key}")
1345
-
1346
-
1347
1345
  def _function_forbid_reuse(func):
1348
1346
  if not inspect.isfunction(func):
1349
1347
  raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
@@ -1535,7 +1533,24 @@ class _PyNativeExecutor:
1535
1533
  Return:
1536
1534
  None.
1537
1535
  """
1538
- return self._executor.grad(grad, obj, weights, grad_position, *args)
1536
+ return self._executor.grad(grad, obj, weights, grad_position, False, *args)
1537
+
1538
+ def grad_aux(self, obj, grad, weights, grad_position, *args):
1539
+ """
1540
+ Run grad graph with aux
1541
+
1542
+ Args:
1543
+ obj (Function/Cell): The function or cell instance.
1544
+ grad (GradOperation): The gradoperation object.
1545
+ weights (ParameterTuple): The weights of cell instance.
1546
+ grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
1547
+ If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
1548
+ args (tuple): Function or cell input arguments.
1549
+
1550
+ Return:
1551
+ None.
1552
+ """
1553
+ return self._executor.grad(grad, obj, weights, grad_position, True, *args)
1539
1554
 
1540
1555
  def clear_res(self):
1541
1556
  """
@@ -1671,6 +1686,15 @@ class _PyNativeExecutor:
1671
1686
  """
1672
1687
  self._executor.set_is_run_recompute(status)
1673
1688
 
1689
+ def high_order(self):
1690
+ """
1691
+ Is high order of current scene, this is a inner interface.
1692
+
1693
+ Return:
1694
+ Bool.
1695
+ """
1696
+ return self._executor.high_order()
1697
+
1674
1698
  def set_cell_use_dynamic_shape_process(self, flag):
1675
1699
  """
1676
1700
  Set the dynamic shape flag of eval process.
@@ -1753,7 +1777,6 @@ class _CellGraphExecutor:
1753
1777
  # create needed graph by lazy mode
1754
1778
  self.is_init = False
1755
1779
  self.enable_tuple_broaden = False
1756
- self.obfuscate_config = None # used for model's dynamic obfuscation
1757
1780
  self._graph_executor = GraphExecutor_.get_instance()
1758
1781
  self._graph_executor.set_py_exe_path(sys.executable)
1759
1782
  self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
@@ -1845,6 +1868,7 @@ class _CellGraphExecutor:
1845
1868
  Str, the full phase of the cell.
1846
1869
  Bool, if the graph has been compiled before, return False, else return True.
1847
1870
  """
1871
+ _init_auto_parallel_context(obj)
1848
1872
  obj.__parse_method__ = 'construct'
1849
1873
  if not hasattr(obj, obj.__parse_method__):
1850
1874
  raise AttributeError(
@@ -1877,6 +1901,7 @@ class _CellGraphExecutor:
1877
1901
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
1878
1902
  # generated in generate_arguments_key.
1879
1903
  self._graph_executor.clear_compile_arguments_resource()
1904
+ _clear_auto_parallel_context(obj)
1880
1905
  return phase, False
1881
1906
 
1882
1907
  full_function_name = obj.__class__.__name__ + '.' + str(obj.instance_count) + '.' + str(id(type(obj)))
@@ -1894,7 +1919,8 @@ class _CellGraphExecutor:
1894
1919
  else:
1895
1920
  jit_config_dict = JitConfig().jit_config_dict
1896
1921
  self._graph_executor.set_jit_config(jit_config_dict)
1897
- result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
1922
+ gc.collect()
1923
+ result = self._graph_executor.compile(obj, args, kwargs, phase)
1898
1924
  obj.compile_cache.add(phase)
1899
1925
  if not result:
1900
1926
  raise RuntimeError("Executor compile failed.")
@@ -1915,6 +1941,7 @@ class _CellGraphExecutor:
1915
1941
  self._build_data_graph(obj, phase)
1916
1942
  elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
1917
1943
  _parameter_broadcast(obj)
1944
+ _clear_auto_parallel_context(obj)
1918
1945
  return phase, True
1919
1946
 
1920
1947
  def _update_param_node_default_input(self, phase, replace):
@@ -1994,25 +2021,12 @@ class _CellGraphExecutor:
1994
2021
  """Clear the memory resource of a network."""
1995
2022
  self._graph_executor.del_net_res(obj, net_id)
1996
2023
 
1997
- def _get_branch_control_input(self):
1998
- if ('obf_ratio' not in self.obfuscate_config.keys()) or (
1999
- 'obf_random_seed' not in self.obfuscate_config.keys()):
2000
- raise ValueError("'obf_ratio' and 'obf_random_seed' must be in obfuscate_config.")
2001
- obf_random_seed = self.obfuscate_config.get('obf_random_seed')
2002
- if obf_random_seed == 0:
2003
- branch_control_input = 0
2004
- else:
2005
- branch_control_input = _generate_branch_control_input(obf_random_seed)
2006
- return branch_control_input
2007
-
2008
2024
  def _get_func_graph(self, obj, exec_id, use_prefix=False):
2009
2025
  """Get func graph from pipeline."""
2010
2026
  if use_prefix:
2011
2027
  exec_id = exec_id + '.' + obj.arguments_key
2012
2028
  if self._graph_executor.has_compiled(exec_id) is False:
2013
2029
  return None
2014
- if self.obfuscate_config is not None:
2015
- raise ValueError('For get func graph, obfuscate_config is currently not supported now.')
2016
2030
  return self._graph_executor.get_func_graph(exec_id)
2017
2031
 
2018
2032
  def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
@@ -2021,11 +2035,6 @@ class _CellGraphExecutor:
2021
2035
  exec_id = exec_id + '.' + obj.arguments_key
2022
2036
  if self._graph_executor.has_compiled(exec_id) is False:
2023
2037
  return None
2024
- if self.obfuscate_config is not None:
2025
- branch_control_input = self._get_branch_control_input()
2026
- return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, incremental,
2027
- self.obfuscate_config['obf_ratio'],
2028
- branch_control_input)
2029
2038
  return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
2030
2039
 
2031
2040
  def get_optimize_graph_proto(self, obj):
@@ -2063,6 +2072,8 @@ def ms_memory_recycle():
2063
2072
  """
2064
2073
  if ms_compile_cache:
2065
2074
  _cell_graph_executor.del_net_res(None, ms_compile_cache)
2075
+ if os.getenv('MS_DEV_JIT_PIPELINE') != '0':
2076
+ JitExecutor_.get_instance().del_net_res(None, ms_compile_cache)
2066
2077
  ms_compile_cache.clear()
2067
2078
  for cell_cache in cells_compile_cache.values():
2068
2079
  if cell_cache:
@@ -2089,30 +2100,6 @@ def set_recursion_limit(recursion_limit=1000):
2089
2100
  GraphExecutor_.get_instance().set_max_call_depth(recursion_limit)
2090
2101
 
2091
2102
 
2092
- def _generate_branch_control_input(obf_random_seed):
2093
- """Generate append network input for dynamic obfuscation in random seed mode."""
2094
- seed_max = 2 ** 32 - 1
2095
- int_max = 2 ** 31 - 1
2096
- np.random.seed(obf_random_seed % seed_max)
2097
- # generate a string as hash function inputs
2098
- word_repo = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghigklmnopqrstuvwxyz" + "0123456789"
2099
- repo_len = len(word_repo)
2100
- sha_string = ''
2101
- string_len = 1024 * 1024
2102
- for _ in range(string_len):
2103
- rand_index = np.random.randint(0, repo_len)
2104
- sha_string += word_repo[rand_index]
2105
- # get hash result
2106
- sha_result = hashlib.sha256(sha_string.encode('utf-8')).hexdigest() # len is 64
2107
- branch_control_input = 1
2108
- hex_base = 16
2109
- for item in sha_result:
2110
- if int(item, hex_base) > 0:
2111
- branch_control_input *= int(item, hex_base)
2112
- branch_control_input %= int_max
2113
- return branch_control_input
2114
-
2115
-
2116
2103
  def _bind_device_context():
2117
2104
  """Bind device context to current thread"""
2118
2105
  _bind_device_ctx()
@@ -2135,4 +2122,4 @@ def flops_collection(phase='train'):
2135
2122
  _cell_graph_executor = _CellGraphExecutor()
2136
2123
  _pynative_executor = _PyNativeExecutor()
2137
2124
 
2138
- __all__ = ['ms_function', 'ms_memory_recycle', 'ms_class', 'jit', 'jit_class', 'flops_collection']
2125
+ __all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']