mindspore 2.1.0__cp38-none-any.whl → 2.2.10__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (569)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  196. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  197. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  198. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  199. mindspore/nn/__init__.py +0 -2
  200. mindspore/nn/cell.py +313 -74
  201. mindspore/nn/dynamic_lr.py +21 -21
  202. mindspore/nn/layer/activation.py +22 -30
  203. mindspore/nn/layer/basic.py +15 -13
  204. mindspore/nn/layer/channel_shuffle.py +1 -1
  205. mindspore/nn/layer/container.py +271 -9
  206. mindspore/nn/layer/conv.py +323 -204
  207. mindspore/nn/layer/dense.py +8 -5
  208. mindspore/nn/layer/embedding.py +33 -27
  209. mindspore/nn/layer/flash_attention.py +141 -88
  210. mindspore/nn/layer/image.py +8 -6
  211. mindspore/nn/layer/math.py +16 -25
  212. mindspore/nn/layer/normalization.py +107 -66
  213. mindspore/nn/layer/padding.py +1 -1
  214. mindspore/nn/layer/pooling.py +131 -109
  215. mindspore/nn/layer/rnn_cells.py +27 -22
  216. mindspore/nn/layer/rnns.py +13 -16
  217. mindspore/nn/layer/thor_layer.py +1 -1
  218. mindspore/nn/layer/transformer.py +221 -154
  219. mindspore/nn/learning_rate_schedule.py +9 -1
  220. mindspore/nn/loss/loss.py +235 -174
  221. mindspore/nn/optim/ada_grad.py +2 -1
  222. mindspore/nn/optim/adadelta.py +1 -0
  223. mindspore/nn/optim/adafactor.py +2 -1
  224. mindspore/nn/optim/adam.py +7 -4
  225. mindspore/nn/optim/adamax.py +3 -2
  226. mindspore/nn/optim/adasum.py +2 -2
  227. mindspore/nn/optim/asgd.py +2 -3
  228. mindspore/nn/optim/ftrl.py +6 -5
  229. mindspore/nn/optim/lamb.py +7 -4
  230. mindspore/nn/optim/lars.py +1 -1
  231. mindspore/nn/optim/lazyadam.py +5 -3
  232. mindspore/nn/optim/momentum.py +2 -1
  233. mindspore/nn/optim/optimizer.py +53 -4
  234. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  235. mindspore/nn/optim/rmsprop.py +4 -3
  236. mindspore/nn/optim/rprop.py +23 -12
  237. mindspore/nn/optim/sgd.py +26 -11
  238. mindspore/nn/optim/thor.py +9 -7
  239. mindspore/nn/probability/bijector/bijector.py +5 -5
  240. mindspore/nn/probability/bijector/power_transform.py +27 -27
  241. mindspore/nn/probability/bijector/softplus.py +3 -3
  242. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  243. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  244. mindspore/nn/probability/distribution/beta.py +3 -3
  245. mindspore/nn/probability/distribution/categorical.py +7 -7
  246. mindspore/nn/probability/distribution/cauchy.py +0 -1
  247. mindspore/nn/probability/distribution/distribution.py +3 -3
  248. mindspore/nn/probability/distribution/gamma.py +3 -3
  249. mindspore/nn/probability/distribution/geometric.py +4 -4
  250. mindspore/nn/probability/distribution/gumbel.py +4 -4
  251. mindspore/nn/probability/distribution/log_normal.py +2 -2
  252. mindspore/nn/probability/distribution/logistic.py +2 -2
  253. mindspore/nn/probability/distribution/poisson.py +4 -4
  254. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  255. mindspore/nn/probability/distribution/uniform.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +84 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  274. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  275. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  276. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  277. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  278. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  279. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  280. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  281. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  283. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  284. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  285. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  286. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  287. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  288. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  289. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  290. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  291. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  292. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  293. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  294. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  295. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  297. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  298. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  299. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  300. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  301. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  302. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  303. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  304. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  305. mindspore/ops/_primitive_cache.py +1 -1
  306. mindspore/ops/_tracefunc.py +45 -13
  307. mindspore/ops/_utils/utils.py +6 -1
  308. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  309. mindspore/ops/_vmap/vmap_base.py +3 -3
  310. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  311. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  312. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  313. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  314. mindspore/ops/arg_dtype_cast.py +54 -0
  315. mindspore/ops/composite/base.py +37 -10
  316. mindspore/ops/composite/math_ops.py +5 -4
  317. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  318. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  319. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  320. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  321. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  324. mindspore/ops/deprecated.py +304 -0
  325. mindspore/ops/function/__init__.py +4 -1
  326. mindspore/ops/function/array_func.py +174 -193
  327. mindspore/ops/function/clip_func.py +81 -13
  328. mindspore/ops/function/debug_func.py +1 -1
  329. mindspore/ops/function/grad/grad_func.py +18 -9
  330. mindspore/ops/function/image_func.py +10 -4
  331. mindspore/ops/function/linalg_func.py +5 -5
  332. mindspore/ops/function/math_func.py +575 -386
  333. mindspore/ops/function/nn_func.py +568 -260
  334. mindspore/ops/function/random_func.py +88 -57
  335. mindspore/ops/function/sparse_func.py +1 -1
  336. mindspore/ops/function/sparse_unary_func.py +14 -12
  337. mindspore/ops/function/vmap_func.py +6 -5
  338. mindspore/ops/functional.py +15 -10
  339. mindspore/ops/op_info_register.py +244 -25
  340. mindspore/ops/operations/__init__.py +28 -19
  341. mindspore/ops/operations/_grad_ops.py +72 -7
  342. mindspore/ops/operations/_inner_ops.py +350 -17
  343. mindspore/ops/operations/_quant_ops.py +4 -8
  344. mindspore/ops/operations/_sequence_ops.py +42 -0
  345. mindspore/ops/operations/array_ops.py +68 -282
  346. mindspore/ops/operations/comm_ops.py +107 -59
  347. mindspore/ops/operations/custom_ops.py +94 -70
  348. mindspore/ops/operations/debug_ops.py +8 -4
  349. mindspore/ops/operations/image_ops.py +18 -12
  350. mindspore/ops/operations/inner_ops.py +26 -3
  351. mindspore/ops/operations/math_ops.py +189 -141
  352. mindspore/ops/operations/nn_ops.py +794 -489
  353. mindspore/ops/operations/other_ops.py +0 -22
  354. mindspore/ops/operations/random_ops.py +53 -111
  355. mindspore/ops/operations/sparse_ops.py +3 -1
  356. mindspore/ops/primitive.py +24 -18
  357. mindspore/parallel/_auto_parallel_context.py +68 -8
  358. mindspore/parallel/_cost_model_context.py +2 -2
  359. mindspore/parallel/_offload_context.py +17 -3
  360. mindspore/parallel/_parallel_serialization.py +12 -5
  361. mindspore/parallel/_ps_context.py +12 -0
  362. mindspore/parallel/_tensor.py +18 -13
  363. mindspore/parallel/_transformer/layers.py +5 -3
  364. mindspore/parallel/_transformer/loss.py +1 -0
  365. mindspore/parallel/_transformer/moe.py +2 -2
  366. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  367. mindspore/parallel/_transformer/transformer.py +23 -3
  368. mindspore/parallel/_utils.py +11 -7
  369. mindspore/parallel/algo_parameter_config.py +85 -5
  370. mindspore/parallel/checkpoint_transform.py +19 -12
  371. mindspore/parallel/shard.py +21 -14
  372. mindspore/profiler/common/struct_type.py +3 -3
  373. mindspore/profiler/common/util.py +4 -2
  374. mindspore/profiler/envprofiling.py +1 -1
  375. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  376. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  377. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  378. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  379. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  380. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  381. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  382. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  383. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  384. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  385. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  386. mindspore/profiler/parser/flops_parser.py +15 -11
  387. mindspore/profiler/parser/framework_parser.py +38 -22
  388. mindspore/profiler/parser/hccl_parser.py +16 -12
  389. mindspore/profiler/parser/integrator.py +22 -11
  390. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  391. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  392. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  393. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  394. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  395. mindspore/profiler/parser/optime_parser.py +1 -1
  396. mindspore/profiler/parser/profiler_info.py +21 -2
  397. mindspore/profiler/parser/step_trace_parser.py +11 -14
  398. mindspore/profiler/profiling.py +179 -89
  399. mindspore/rewrite/api/node.py +102 -19
  400. mindspore/rewrite/api/node_type.py +5 -1
  401. mindspore/rewrite/api/pattern_engine.py +1 -1
  402. mindspore/rewrite/api/scoped_value.py +9 -17
  403. mindspore/rewrite/api/symbol_tree.py +131 -47
  404. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  405. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  406. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  407. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  408. mindspore/rewrite/common/rewrite_elog.py +5 -1
  409. mindspore/rewrite/namer.py +33 -24
  410. mindspore/rewrite/namespace.py +14 -5
  411. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  412. mindspore/rewrite/node/call_function.py +79 -0
  413. mindspore/rewrite/node/cell_container.py +135 -0
  414. mindspore/rewrite/node/control_flow.py +88 -0
  415. mindspore/rewrite/{node.py → node/node.py} +273 -234
  416. mindspore/rewrite/node/node_manager.py +254 -0
  417. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  418. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  419. mindspore/rewrite/parsers/assign_parser.py +216 -221
  420. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  421. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  422. mindspore/rewrite/parsers/constant_parser.py +9 -6
  423. mindspore/rewrite/parsers/container_parser.py +9 -7
  424. mindspore/rewrite/parsers/for_parser.py +36 -15
  425. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  426. mindspore/rewrite/parsers/if_parser.py +28 -24
  427. mindspore/rewrite/parsers/module_parser.py +196 -25
  428. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  429. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  430. mindspore/rewrite/parsers/return_parser.py +6 -6
  431. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  432. mindspore/rewrite/sparsify/utils.py +1 -1
  433. mindspore/rewrite/symbol_tree.py +523 -578
  434. mindspore/rewrite/symbol_tree_builder.py +9 -193
  435. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  436. mindspore/run_check/_check_version.py +6 -4
  437. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  438. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  439. mindspore/scipy/linalg.py +1 -1
  440. mindspore/scipy/optimize/minimize.py +7 -3
  441. mindspore/train/_utils.py +7 -3
  442. mindspore/train/amp.py +323 -123
  443. mindspore/train/anf_ir_pb2.py +14 -2
  444. mindspore/train/callback/_backup_and_restore.py +2 -12
  445. mindspore/train/callback/_callback.py +29 -4
  446. mindspore/train/callback/_checkpoint.py +23 -8
  447. mindspore/train/callback/_early_stop.py +2 -2
  448. mindspore/train/callback/_landscape.py +4 -4
  449. mindspore/train/callback/_loss_monitor.py +2 -2
  450. mindspore/train/callback/_on_request_exit.py +2 -2
  451. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  452. mindspore/train/callback/_summary_collector.py +15 -8
  453. mindspore/train/callback/_time_monitor.py +58 -5
  454. mindspore/train/data_sink.py +5 -11
  455. mindspore/train/dataset_helper.py +84 -57
  456. mindspore/train/loss_scale_manager.py +2 -2
  457. mindspore/train/metrics/__init__.py +3 -3
  458. mindspore/train/metrics/cosine_similarity.py +1 -1
  459. mindspore/train/metrics/hausdorff_distance.py +3 -2
  460. mindspore/train/metrics/mean_surface_distance.py +3 -2
  461. mindspore/train/metrics/metric.py +39 -19
  462. mindspore/train/metrics/roc.py +2 -2
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  464. mindspore/train/mind_ir_pb2.py +85 -36
  465. mindspore/train/model.py +187 -47
  466. mindspore/train/serialization.py +487 -161
  467. mindspore/train/summary/_summary_adapter.py +1 -1
  468. mindspore/train/summary/_writer_pool.py +3 -2
  469. mindspore/train/summary/summary_record.py +37 -17
  470. mindspore/train/train_thor/convert_utils.py +3 -3
  471. mindspore/train/train_thor/dataset_helper.py +1 -1
  472. mindspore/version.py +1 -1
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  474. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +477 -517
  475. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  476. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  477. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  478. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  479. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  480. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  481. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  482. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  483. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  484. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  485. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  486. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  487. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  488. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  489. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  490. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  491. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  492. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  493. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  494. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  495. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  496. mindspore/_extends/graph_kernel/expander.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  498. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  499. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  500. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  501. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  502. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  503. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  504. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  505. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  507. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  508. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  509. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  510. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  511. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  514. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  516. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  517. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  518. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  519. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  520. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  521. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  522. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  523. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  524. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  526. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  527. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  528. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  529. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  531. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  533. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  534. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  535. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  536. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  537. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  538. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  539. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  540. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  541. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  542. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  543. mindspore/dataset/datapreprocess/__init__.py +0 -20
  544. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  545. mindspore/include/api/net.h +0 -142
  546. mindspore/nn/lr_scheduler.py +0 -262
  547. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  548. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  549. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  550. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  551. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  552. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  553. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  554. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  555. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  556. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  559. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  567. mindspore/rewrite/node_visitor.py +0 -44
  568. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  569. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -26,13 +26,14 @@ import inspect
  import importlib
  import hashlib
  import contextlib
- from collections import OrderedDict
+ from collections import OrderedDict, namedtuple
  from functools import wraps
  import numpy as np
  import mindspore as ms
  from mindspore import context
  from mindspore import log as logger
  from mindspore._extends.remote import kernel_build_server
+ from mindspore.common.jit_config import JitConfig
  from mindspore.common.tensor import Tensor as PythonTensor
  from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
  from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
@@ -48,11 +49,15 @@ from mindspore._checkparam import is_stub_tensor
  from mindspore.common._utils import is_shape_unknown
  from mindspore.common.mutable import mutable
  from mindspore.common._register_for_adapter import ms_adapter_registry
+ from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
+ get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature

  # Store ms_function class compiled pipeline cache.
  ms_compile_cache = set()
  # Store cell compiled pipeline cache.
  cells_compile_cache = {}
+ # Store function compiled times information.
+ function_phases = dict()

  BROADCAST_PHASE = "_broadcast_"
  _PYNATIVE_PARALLEL_FUNC_NAME = "after_shard"
@@ -79,6 +84,12 @@ def _convert_python_data(data):
  if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
  return PythonRowTensor(row_tensor=data)
  if isinstance(data, tuple):
+ # Handle namedtuple since its type is tuple.
+ if hasattr(data, "_fields"):
+ type_name = data.__class__.__name__
+ data_dict = data._asdict()
+ fields = data_dict.keys()
+ return namedtuple(type_name, fields)(**_convert_python_data(data_dict))
  return tuple(_convert_python_data(x) for x in data)
  if isinstance(data, list):
  # Keep list object not change for inplace operation.
@@ -86,7 +97,11 @@ def _convert_python_data(data):
  data[i] = _convert_python_data(data[i])
  return data
  if isinstance(data, dict):
- return dict((_convert_python_data(key), _convert_python_data(value)) for key, value in data.items())
+ # Keep the dict object not change.
+ keys = tuple(data.keys())
+ for key in keys:
+ data[_convert_python_data(key)] = _convert_python_data(data.pop(key))
+ return data
  return data
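
The added namedtuple branch above exists because a namedtuple is just a tuple: converting it with tuple(...) would silently drop its field names, so it is detected through its _fields attribute and rebuilt field by field. A minimal standalone sketch of that round-trip (convert() is a hypothetical stand-in for the real per-element conversion, not the MindSpore function):

    from collections import namedtuple

    def convert(data):
        # Stand-in for the per-element conversion done by _convert_python_data.
        if isinstance(data, tuple) and hasattr(data, "_fields"):   # namedtuple detection
            converted = {k: convert(v) for k, v in data._asdict().items()}
            return namedtuple(type(data).__name__, converted.keys())(**converted)
        if isinstance(data, tuple):
            return tuple(convert(x) for x in data)
        return data

    Point = namedtuple("Point", ["x", "y"])
    p = convert(Point(1, 2))
    print(p.x, p.y)   # field names survive the conversion: 1 2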


@@ -175,8 +190,7 @@ def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
  if isinstance(node, ast.ImportFrom):
  if node.module is not None:
  module_name = node.module
- if node.level == 1:
- module_name = "." + module_name
+ module_name = "." * node.level + module_name
  elif not isinstance(node, ast.Import):
  continue
  # Do not care the files in mindspore package
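
The change above fixes relative-import tracking in the compile-cache dependency scan: the old code only handled a single leading dot (level == 1), so deeper relative imports lost dots, while "." * node.level reproduces the prefix for any depth. A small standalone illustration of how the standard-library ast module reports the level:

    import ast

    node = ast.parse("from ..pkg.sub import helper").body[0]   # ast.ImportFrom
    print(node.module, node.level)                             # pkg.sub 2
    print("." * node.level + node.module)                      # ..pkg.sub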
@@ -284,195 +298,6 @@ def _get_args_for_run(obj, args, kwargs):
  return new_args


- class _AutoIdentifyDynamicShape:
-
- """
- Represents a function auto identify dynamic shape.
- """
- def __init__(self):
- self.all_shape_cache = {}
-
-
- @staticmethod
- def get_input_tensor_shape(args_list):
- """get input tensor shape and type save as tensor, and make it value to 1"""
- tensor_list = []
- for arg in args_list:
- if isinstance(arg, Tensor):
- tmp_shape = arg.shape
- tmp_type = arg.dtype
- tensor_list.append(PythonTensor(np.ones(tmp_shape), dtype=tmp_type))
- else:
- tensor_list.append(arg)
-
- return tuple(tensor_list)
-
-
- @staticmethod
- def check_input_args(args_list):
- """check input args"""
- if not args_list:
- return False
-
- for elem in args_list:
- if elem is None:
- return False
-
- if isinstance(elem, ms.Parameter):
- return False
-
- if not isinstance(elem, Tensor):
- return False
-
- if elem.const_arg:
- return False
-
- return True
-
-
- @staticmethod
- def _is_tensor_equal(input_tensor, cache_tensor):
- """check two tensor is equal"""
- if input_tensor.dtype != cache_tensor.dtype:
- return False
-
- if input_tensor.shape != cache_tensor.shape:
- return False
-
- if len(input_tensor.shape) != len(cache_tensor.shape):
- return False
-
- return True
-
-
- @staticmethod
- def _is_all_input_shape_generalize(input_shape_tuple):
- """check all input shapes need generalize"""
- for elem in input_shape_tuple:
- if not is_shape_unknown(elem.shape):
- return False
- return True
-
-
- def auto_dynamic_generate_compile_args(self, args_list):
- """generate compile args in auto dynamic shape"""
- if not self._is_enable_auto_identify_shape(args_list):
- return args_list
-
- args_len = len(args_list)
- tensor_tuple = self.get_input_tensor_shape(args_list)
- shape_cache_list = self.all_shape_cache.get(args_len)
- # step1: init real_shape_cache, part_generalize_shape_cache, all_generalize_shape_cache.
- if shape_cache_list is None:
- shape_cache_list = []
- real_shape_cache = set()
- real_shape_cache.add(tensor_tuple)
- shape_cache_list.append(real_shape_cache)
- part_generalize_shape_cache = set()
- shape_cache_list.append(part_generalize_shape_cache)
- all_generalize_shape_cache = set()
- shape_cache_list.append(all_generalize_shape_cache)
- self.all_shape_cache[args_len] = shape_cache_list
- logger.info((f'The real shape cache is empty, add it into real_shape_cache.'))
- return tensor_tuple
-
- # step2: find cache in real_shape_cache.
- real_shape_cache = shape_cache_list[0]
- is_real_shape_exist, real_shape_input = self._find_compile_args_in_shape_cache(real_shape_cache, tensor_tuple,
- "real")
- if is_real_shape_exist and real_shape_input is not None:
- return real_shape_input
-
- # step3: if can not find cache in real_shape_cache, then generate it
- is_generalize_shape, compile_args = self._do_generalize_shape(real_shape_cache, tensor_tuple)
-
- # step4: if input type change or rank change, save shape into real_shape_cache and then return
- if not is_generalize_shape and compile_args is None:
- real_shape_cache.add(tensor_tuple)
- return tensor_tuple
-
- # step5: check whether all input tensor need generalize
- all_generalize_shape_cache = shape_cache_list[2]
- if self._is_all_input_shape_generalize(compile_args):
- if not all_generalize_shape_cache:
- all_generalize_shape_cache.add(compile_args)
- logger.info((f'return all generalize shape cache.'))
- return compile_args
-
- # step6: find compile_args in part_generalize_shape_cache
- part_generalize_shape_cache = shape_cache_list[1]
- if not part_generalize_shape_cache:
- part_generalize_shape_cache.add(compile_args)
- else:
- is_generalize_shape_exist, _ = self._find_compile_args_in_shape_cache(part_generalize_shape_cache,
- compile_args, "part generalize")
- if not is_generalize_shape_exist:
- logger.info((f'Can not find cache in part_generalize_shape_cache, add it into'
- ' part_generalize_shape_cache.'))
- part_generalize_shape_cache.add(compile_args)
-
- return compile_args
-
-
- def _is_all_tensor_equal(self, input_shape_tuple, cache_shape_tuple):
- """check two tuple is equal"""
- for i, elem in enumerate(cache_shape_tuple):
- res = self._is_tensor_equal(input_shape_tuple[i], elem)
- if not res:
- return False
- return True
-
-
- def _is_enable_auto_identify_shape(self, args_list):
- """is enable auto identify shape"""
- enable_auto_identify = os.getenv('MS_AUTO_DYNAMIC_SHAPE_ENABLE')
- if not enable_auto_identify:
- enable_auto_identify = False
- if ((enable_auto_identify is False or enable_auto_identify == "0")) or not self.check_input_args(args_list):
- return False
- return True
-
-
- def _find_compile_args_in_shape_cache(self, shape_cache, compile_args, cache_type):
- """find compile args in real or part generalize shape cache"""
- is_exist = False
- for shapes in shape_cache:
- is_exist = self._is_all_tensor_equal(compile_args, shapes)
- if is_exist:
- logger.info((f'Find cache in {cache_type} shape cache.'))
- return is_exist, shapes
- logger.info((f'Can not find cache in {cache_type} shape cache.'))
- return is_exist, None
-
-
- def _do_generalize_shape(self, real_shape_cache, tensor_tuple):
- """do generalize shape"""
- is_generalize_shape = False
- for real_shape in real_shape_cache:
- generalize_shape = []
- for i, elem in enumerate(real_shape):
- if len(elem.shape) != len(tensor_tuple[i].shape) or elem.dtype != tensor_tuple[i].dtype:
- generalize_shape.clear()
- break
- if (not is_shape_unknown(elem.shape)) and self._is_tensor_equal(tensor_tuple[i], elem):
- generalize_shape.append(tensor_tuple[i])
- else:
- shape_value = []
- for _ in range(len(elem.shape)):
- shape_value.append(-1)
- shape_tuple = tuple(shape_value)
- generalize_shape.append(PythonTensor(Tensor(shape=shape_tuple, dtype=tensor_tuple[i].dtype)))
- logger.info((f'The {i} input tensor shape is {tensor_tuple[i].shape}, type is '
- f'{tensor_tuple[i].dtype}; in real cache shape is {elem.shape}, type is '
- f'{elem.dtype}, the {i} input shape not equal, may generalize to {shape_tuple}.'))
-
- if len(generalize_shape) == len(real_shape):
- is_generalize_shape = True
- return is_generalize_shape, tuple(generalize_shape)
-
- return is_generalize_shape, None
-
-
  class _MindsporeFunctionExecutor:
  """
  Represents a function compiled by graph compiler.
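
The removed _AutoIdentifyDynamicShape class is not dropped from the package: per the file list above, its logic moves into the new mindspore/common/auto_dynamic_shape.py module and is now reached through get_auto_dynamic_shape_args_with_check_input_signature. For reference, the core rule of the removed generalization step: when a cached argument matches in rank and dtype but not in concrete shape, the compile argument is generalized to a fully unknown shape (every dimension -1). A simplified standalone sketch of that per-argument rule (hypothetical helper, not the MindSpore implementation):

    def generalize_shape(cached_shape, new_shape, same_dtype=True):
        # Rank or dtype change: no generalization, compile with the real shape.
        if not same_dtype or len(cached_shape) != len(new_shape):
            return None
        # Exact cache hit: keep the concrete shape.
        if cached_shape == new_shape:
            return new_shape
        # Same rank, different sizes: mark every dimension as dynamic.
        return tuple(-1 for _ in new_shape)

    print(generalize_shape((32, 128), (64, 128)))   # (-1, -1)
    print(generalize_shape((32, 128), (32, 128)))   # (32, 128)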
@@ -490,7 +315,6 @@ class _MindsporeFunctionExecutor:
  Returns:
  The result of pipeline running in graph mode.
  """
-
  def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None):
  init_pipeline()
  if not isinstance(fn, (types.FunctionType, types.MethodType)):
@@ -506,7 +330,7 @@ class _MindsporeFunctionExecutor:
  self._graph_executor = GraphExecutor_.get_instance()
  self._create_time = ms_create_time
  self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
- self.auto_identify_dynamic_shape = _AutoIdentifyDynamicShape()
+

  @_wrap_func
  def __call__(self, *args, **kwargs):
@@ -516,9 +340,9 @@ class _MindsporeFunctionExecutor:
  phase = ""
  try:
  if context.get_context("mode") == context.PYNATIVE_MODE:
- _pynative_executor.set_ms_function_compile_status(True, phase)
+ _pynative_executor.set_jit_compile_status(True, phase)
  phase = self.compile(self.fn.__name__, *args_list, **kwargs)
- _pynative_executor.set_ms_function_compile_status(False, phase)
+ _pynative_executor.set_jit_compile_status(False, phase)
  else:
  phase = self.compile(self.fn.__name__, *args_list, **kwargs)
  except Exception as err:
@@ -531,19 +355,11 @@ class _MindsporeFunctionExecutor:
  new_inputs = self._generate_run_args(args_list, kwargs)
  output = self._graph_executor(tuple(new_inputs), phase)
  if context.get_context("mode") == context.PYNATIVE_MODE:
- output = _pynative_executor.grad_ms_function(output, *new_inputs)
-
- enable_ge = os.getenv("MS_ENABLE_GE") == "1"
- if enable_ge and self.jit_config_dict is None:
- raise RuntimeError("GE and jit_level=O3 should be used together, but jit_config is None.")
- if self.jit_config_dict:
- enable_jit_level_o3 = self.jit_config_dict.get('jit_level') == "O3"
- if (enable_ge and not enable_jit_level_o3) or (not enable_ge and enable_jit_level_o3):
- raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, jit_level={}".
- format(os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))
+ output = _pynative_executor.grad_jit(output, *new_inputs)

  return output

+
  def compile(self, method_name, *args, **kwargs):
  """Returns pipeline for the given args."""
  # Check whether hook function registered on Cell object.
@@ -554,14 +370,16 @@ class _MindsporeFunctionExecutor:
  f"pynative mode and remove 'jit' decorator.")
  # Chose dynamic shape tensors or actual input tensors as compile args.
  compile_args = self._generate_compile_args(args)
+ key_id = self._get_key_id()
+ compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
+ self.input_signature)
+
  # Restore the mutable attr for every arg.
  compile_args = _restore_mutable_attr(args, compile_args)
- generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
- str(self.fn.__code__.co_firstlineno)
- if _pynative_executor.grad_flag():
- generate_name = generate_name + ".grad"
- if _is_pynative_parallel():
- generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
+ generate_name, echo_function_name = self._get_generate_name()
+ # The full Function name
+ full_function_name = generate_name
+ create_time = ''

  # Add key with obj
  if self.obj is not None:
@@ -572,13 +390,18 @@ class _MindsporeFunctionExecutor:
  self.obj.__parse_method__ = method_name
  if isinstance(self.obj, ms.nn.Cell):
  generate_name = generate_name + '.' + str(self.obj.create_time)
+ create_time = str(self.obj.create_time)
  else:
  generate_name = generate_name + '.' + str(self._create_time)
+ create_time = str(self._create_time)
+
  generate_name = generate_name + '.' + str(id(self.obj))
+ full_function_name = generate_name
  else:
  # Different instance of same class may use same memory(means same obj_id) at diff times.
  # To avoid unexpected phase matched, add create_time to generate_name.
  generate_name = generate_name + '.' + str(self._create_time)
+ create_time = str(self._create_time)

  self.enable_tuple_broaden = False
  if hasattr(self.obj, "enable_tuple_broaden"):
@@ -587,16 +410,33 @@ class _MindsporeFunctionExecutor:
  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
  key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
  phase = generate_name + '.' + str(key)
+
+ update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
+
  if phase in ms_compile_cache:
  return phase

+ self._check_recompile(full_function_name, create_time, echo_function_name)
+
  # If enable compile cache, get the dependency files list and set to graph executor.
  self._set_compile_cache_dep_files()
  if self.jit_config_dict:
  self._graph_executor.set_jit_config(self.jit_config_dict)
+ else:
+ jit_config_dict = JitConfig().jit_config_dict
+ self._graph_executor.set_jit_config(jit_config_dict)

  if self.obj is None:
+ # Set an attribute to fn as an identifier.
+ if isinstance(self.fn, types.MethodType):
+ setattr(self.fn.__func__, "__jit_function__", True)
+ else:
+ setattr(self.fn, "__jit_function__", True)
  is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
+ if isinstance(self.fn, types.MethodType):
+ delattr(self.fn.__func__, "__jit_function__")
+ else:
+ delattr(self.fn, "__jit_function__")
  else:
  if isinstance(self.obj, ms.nn.Cell):
  self._graph_executor.set_weights_values(self.obj.parameters_dict())
@@ -605,8 +445,32 @@ class _MindsporeFunctionExecutor:
  if not is_compile:
  raise RuntimeError("Executor compile failed.")
  ms_compile_cache.add(phase)
+
  return phase

+ def _check_recompile(self, full_function_name, create_time, echo_function_name):
+ """Warning when the function has been compiled."""
+ ignore_dirs = ["mindspore/ops", "mindspore/nn"]
+ if any((lambda x: x in full_function_name)(x) for x in ignore_dirs):
+ return
+
+ if full_function_name in function_phases:
+ warning_times = 1
+ if len(function_phases[full_function_name]) >= warning_times \
+ and create_time not in function_phases[full_function_name]:
+ tips = "Try to decorate the function with @jit(hash_args=...) " \
+ "or @jit(compile_once=True) to reduce the compile time. " \
+ "For more details, get instructions about `jit` at " \
+ "https://www.mindspore.cn/search?inputValue=jit."
+
+ logger.warning(f"The {echo_function_name} has been compiled again. "
+ f"{tips} ")
+ else:
+ function_phases[full_function_name] = set()
+
+ function_phases[full_function_name].add(create_time)
+
+
  @staticmethod
  def _optimizer_state_init(opt_states):
  """set data for all optimizer states in case it is executed in graph mode"""
@@ -618,6 +482,31 @@ class _MindsporeFunctionExecutor:
  opt_param.init_data()


+ def _get_key_id(self):
+ """get key id."""
+ if isinstance(self.obj, ms.nn.Cell):
+ key_id = str(id(self.obj)) + str(self.obj.create_time)
+ else:
+ key_id = str(id(self.obj)) + str(self._create_time)
+
+ if _pynative_executor.grad_flag():
+ key_id = key_id + ".grad"
+ return key_id
+
+
+ def _get_generate_name(self):
+ """get generate name."""
+ generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + str(
+ self.fn.__code__.co_firstlineno)
+ echo_function_name = "function \"" + self.fn.__name__ + "\" at the file \"" + self.fn.__code__.co_filename \
+ + "\", line " + str(self.fn.__code__.co_firstlineno)
+ if _pynative_executor.grad_flag():
+ generate_name = generate_name + ".grad"
+ if _is_pynative_parallel():
+ generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
+ return generate_name, echo_function_name
+
+
  def _set_compile_cache_dep_files(self):
  # If enable compile cache, get the dependency files list
  enable_compile_cache = context.get_context("enable_compile_cache")
@@ -630,7 +519,7 @@ class _MindsporeFunctionExecutor:
  def _generate_compile_args(self, args_list):
  """Chose dynamic shape tensors or actual input tensors as compile args."""
  # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
- compile_args = args_list
+ compile_args = _pynative_executor.get_dynamic_input(args_list)
  # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
  if self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
  compile_args = self.obj.get_inputs()
@@ -659,12 +548,13 @@ class _MindsporeFunctionExecutor:
  f"be 'sens' and added it to compile args.")
  self.input_signature.append(args_list[-1])
  compile_args = tuple(self.input_signature)
- _pynative_executor.set_dynamic_input(self.obj)
+ if self.obj is not None:
+ _pynative_executor.set_dynamic_input(self.obj, *compile_args)
+ else:
+ _pynative_executor.set_dynamic_input(self.fn, *compile_args)
  else:
  if not verify_inputs_signature(self.input_signature, args_list):
  raise ValueError("The input args is incompatible with the args in `input_signature`!")
- else:
- compile_args = self.auto_identify_dynamic_shape.auto_dynamic_generate_compile_args(args_list)
  return compile_args

  def _generate_run_args(self, args_list, kwargs):
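
_generate_compile_args now prefers dynamic-shape tensors registered through Cell.set_inputs() (or an input_signature) over the concrete runtime tensors, so a single compiled graph can serve inputs of varying shape. A hedged usage sketch of that workflow with the public API (the layer choice and shapes are illustrative only):

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    ms.set_context(mode=ms.GRAPH_MODE)
    net = nn.Dense(16, 8)

    # Declare the batch dimension as dynamic; compilation uses this symbolic shape.
    dyn_x = Tensor(shape=[None, 16], dtype=ms.float32)
    net.set_inputs(dyn_x)

    out_small = net(Tensor(np.ones((4, 16), np.float32)))    # both calls reuse the
    out_large = net(Tensor(np.ones((32, 16), np.float32)))   # same compiled graph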
@@ -699,14 +589,14 @@ def _get_obj_id(input_obj):
699
589
  return obj_id + str(id(input_obj))
700
590
 
701
591
 
702
- def _get_ms_function_hash(hash_input):
592
+ def _get_jit_hash(hash_input):
703
593
  """Get hash value of single object or list of objects."""
704
594
  if isinstance(list, tuple):
705
595
  return ".".join(map(_get_obj_id, hash_input))
706
596
  return _get_obj_id(hash_input)
707
597
 
708
598
 
709
- def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
599
+ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None, compile_once=False):
710
600
  """
711
601
  Create a callable MindSpore graph from a Python function.
712
602
 
@@ -726,6 +616,10 @@ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
726
616
  like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
727
617
  will trigger recompilation. Default: ``None`` .
728
618
  jit_config (JitConfig): Jit config for compile. Default: ``None`` .
619
+ compile_once(bool): ``True``: The function would be compiled once when it was created many times.
620
+ But it may be wrong if the free variables were changed. ``False`` : It would be recompiled when
621
+ it was created again
622
+ Default: ``False`` .
729
623
 
730
624
  Returns:
731
625
  Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
@@ -769,7 +663,7 @@ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
  ...
  >>> out = tensor_add_with_sig(x, y)
  ...
- ... # Set hash_args as fn, otherwise cache of compiled `closure_fn` will not be reused.
+ ... # Set hash_args as fn, otherwise cache of compiled closure_fn will not be reused.
  ... # When fn differs in a later call, recompilation will be triggered.
  >>> def func(x):
  ... return ops.exp(x)
@@ -783,11 +677,28 @@ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
  >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
  >>> for i in range(10):
  ... closure_fn(inputs, func)
+ ...
+ ... # Set compile_once = True, otherwise train_step will be recompiled on every call to train.
+ >>> def train(x):
+ ... @jit(compile_once=True)
+ ... def train_step(x):
+ ... return ops.exp(x)
+ ... for i in range(10):
+ ... train_step(x)
+ ...
+ >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
+ >>> for i in range(10):
+ ... train(inputs)
  """

  def wrap_mindspore(func):
+ if not isinstance(compile_once, bool):
+ logger.warning(f"The parameter `compile_once` of jit should be a bool, "
+ f"but got {type(compile_once)}.")
  if hash_args:
- hash_obj = _get_ms_function_hash(hash_args)
+ hash_obj = _get_jit_hash(hash_args)
+ elif compile_once:
+ hash_obj = 0
  else:
  hash_obj = int(time.time() * 1e9)

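The hunk above is where `compile_once` actually takes effect: the compile cache is keyed by `hash_obj`, so a constant key (0) lets a re-created inner function reuse its compiled graph, while a nanosecond timestamp forces a fresh compile. Below is a standalone toy version of that keying, under the assumption that caching by (qualname, hash key) is a fair stand-in for the real phase bookkeeping; `toy_jit` and `_COMPILE_CACHE` are invented names, not MindSpore APIs.

    import time

    _COMPILE_CACHE = {}

    def toy_jit(fn=None, hash_args=None, compile_once=False):
        def wrap(func):
            if hash_args is not None:
                objs = hash_args if isinstance(hash_args, (list, tuple)) else [hash_args]
                hash_obj = ".".join(str(id(o)) for o in objs)
            elif compile_once:
                hash_obj = 0                          # stable key: re-creation reuses the cache
            else:
                hash_obj = int(time.time() * 1e9)     # unique key: every creation recompiles
            key = (func.__qualname__, hash_obj)

            def staged(*args, **kwargs):
                if key not in _COMPILE_CACHE:
                    print("compiling", key)           # stands in for graph compilation
                    _COMPILE_CACHE[key] = func
                return _COMPILE_CACHE[key](*args, **kwargs)
            return staged
        return wrap if fn is None else wrap(fn)


    def train(x):
        @toy_jit(compile_once=True)
        def train_step(v):
            return v * 2
        return train_step(x)

    for i in range(3):
        train(i)        # "compiling" is printed only on the first iteration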
@@ -984,8 +895,8 @@ def _no_recursive(callable_obj):
  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
  """
- isCellSubClass = inspect.isclass(callable_obj) and issubclass(callable_obj, ms.nn.Cell)
- if not isCellSubClass and not inspect.ismethod(callable_obj) and not inspect.isfunction(callable_obj):
+ is_cell_subclass = inspect.isclass(callable_obj) and issubclass(callable_obj, ms.nn.Cell)
+ if not is_cell_subclass and not inspect.ismethod(callable_obj) and not inspect.isfunction(callable_obj):
  raise TypeError(f"Decorator no_recursive is used for callable object, but got {callable_obj}.")
  _add_flags(callable_obj, no_recursive=True)
  return callable_obj
@@ -1149,7 +1060,7 @@ def _build_broadcast_graph(broadcast_params_dict, broadcast_phase):
  _broadcast_net.phase = broadcast_phase
  broadcasted_params = _broadcast_net()
  for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
- broadcast_params_dict[param_name].set_data(param)
+ broadcast_params_dict.get(param_name).set_data(param)


  def _get_auto_split_param_names(parameter_layout_dict):
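One behavioural note on the `.get()` change above: unlike plain indexing, `dict.get(key)` returns None for a missing key, so a missing parameter name would surface later as an AttributeError on `set_data` rather than an immediate KeyError. A short generic illustration (plain dict, not MindSpore parameters):

    params = {"w": 1.0}
    try:
        params["missing"]                    # plain indexing raises KeyError at the lookup
    except KeyError:
        pass
    assert params.get("missing") is None     # .get() returns None; the failure would surface later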
@@ -1355,30 +1266,18 @@ class _PyNativeExecutor:
  """
  self._executor.sync()

- def set_lazy_build(self, enable):
+ def grad_jit(self, output, *args):
  """
- The switch of lazy build.
+ Build the grad graph for a function or cell decorated by jit.

  Args:
- enable (bool): Specifies whether the lazy build is enable.
+ output (tuple): Output object of the function or cell decorated by jit.
+ args (tuple): Input arguments of the function or cell decorated by jit.

  Return:
  None.
  """
- self._executor.set_lazy_build(enable)
-
- def grad_ms_function(self, output, *args):
- """
- Building grad graph decorated by ms_function.
-
- Args:
- output (tuple): The function or cell decorated by ms_function output object.
- args (tuple): Function or cell decorated by ms_function input arguments.
-
- Return:
- None.
- """
- return self._executor.grad_ms_function(output, *args)
+ return self._executor.grad_jit(output, *args)

  def grad_flag(self):
  """
@@ -1422,29 +1321,42 @@ class _PyNativeExecutor:
  """
  self._executor.set_enable_grad(flag)

- def set_ms_function_compile_status(self, status, phase):
+ def set_jit_compile_status(self, status, phase):
  """
- Set ms_function is compiling
+ Set whether jit is compiling.

  Args:
- status(bool): ms_function compile status
+ status (bool): The jit compile status.
  phase (str): The phase of cell/function instance.
  Return:
  None.
  """
- self._executor.set_ms_function_compile_status(status, phase)
+ self._executor.set_jit_compile_status(status, phase)

- def set_dynamic_input(self, obj):
+ def set_dynamic_input(self, obj, *args):
  """
  Set dynamic shape tensor of input arguments.

  Args:
  obj (Function/Cell): The function or cell instance.
+ args (tuple): Dynamic shape input arguments of the function or cell.

  Return:
  None.
  """
- self._executor.set_dynamic_input(obj)
+ self._executor.set_dynamic_input(obj, *args)
+
+ def get_dynamic_input(self, *actual_args):
+ """
+ Get dynamic shape arguments according to actual input arguments.
+
+ Args:
+ actual_args (tuple): Actual input arguments of the function or cell.
+
+ Return:
+ dynamic_shape_args (tuple): Dynamic shape arguments of the function or cell.
+ """
+ return self._executor.get_dynamic_input(*actual_args)

  def is_first_cell(self):
  """
@@ -1550,6 +1462,13 @@ class _CellGraphExecutor:
  """
  self._graph_executor.set_queue_name(queue_name)

+ def get_queue_name(self, dataset_phase):
+ """
+ Get the cached queue name for the graph loaded from the compile cache.
+
+ Return:
+ The cached queue name.
+ """
+ return self._graph_executor.get_queue_name(dataset_phase)
+
  @staticmethod
  def _set_dataset_mode(obj):
  """set dataset mode."""
@@ -1597,15 +1516,18 @@ class _CellGraphExecutor:
  if not hasattr(obj, obj.__parse_method__):
  raise AttributeError(
  'The class {} does not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
+ key_id = str(id(obj)) + str(obj.create_time)
+ args = get_auto_dynamic_shape_args(args, key_id)

  self.enable_tuple_broaden = False
  if hasattr(obj, "enable_tuple_broaden"):
  self.enable_tuple_broaden = obj.enable_tuple_broaden
-
+ logger.debug("Convert the network: %r.", do_convert)
  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
  key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
  obj.arguments_key = str(key)
  phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+ update_auto_dynamic_shape_phase(args, key_id, phase)

  if phase in obj.compile_cache and self.has_compiled(phase):
  logger.debug("%r graph has existed.", phase)
@@ -1616,12 +1538,12 @@ class _CellGraphExecutor:
  self._set_dataset_mode(obj)
  self._set_compile_cache_dep_files(phase)

- enable_ge = context.get_context("enable_ge")
- if enable_ge:
- obj.add_flags(ge_init=True)
  self._graph_executor.set_weights_values(obj.parameters_dict())
  if jit_config_dict:
  self._graph_executor.set_jit_config(jit_config_dict)
+ else:
+ jit_config_dict = JitConfig().jit_config_dict
+ self._graph_executor.set_jit_config(jit_config_dict)
  result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
  obj.compile_cache.add(phase)
  if not result:
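The `else` branch added above guarantees the graph executor always receives a concrete jit config, falling back to the defaults of `JitConfig()` when the caller supplied none. The generic pattern looks like this, with placeholder default values (the actual defaults are defined by `mindspore.JitConfig`, not by this sketch):

    class _Executor:
        def set_jit_config(self, cfg):
            self.cfg = cfg

    # Placeholder defaults for the sketch; the real ones come from JitConfig().jit_config_dict.
    DEFAULT_JIT_CONFIG = {"jit_level": "O1", "exc_mode": "auto"}

    def apply_jit_config(executor, jit_config_dict=None):
        # Always hand the executor a concrete config so later stages never see an empty one.
        executor.set_jit_config(jit_config_dict or dict(DEFAULT_JIT_CONFIG))

    ex = _Executor()
    apply_jit_config(ex)                          # falls back to the defaults
    apply_jit_config(ex, {"jit_level": "O2"})     # explicit config wins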
@@ -1639,17 +1561,10 @@ class _CellGraphExecutor:
  obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
  obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)

- if not do_convert:
- return phase, True
-
- # the following GE init process is not needed when use vm or ms backend
- if enable_ge:
- pass
- elif "export" in phase:
+ if "export.air" in phase:
  self._build_data_graph(obj, phase)
  elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
  _parameter_broadcast(obj)
-
  return phase, True

  def _update_param_node_default_input(self, phase, replace):