mindspore-1.10.0-cp37-cp37m-win_amd64.whl → mindspore-2.0.0rc1-cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the package registry's release page for more details.

Files changed (966):
  1. mindspore/.commit_id +1 -1
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +9 -4
  18. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +102 -0
  22. mindspore/_checkparam.py +1066 -1001
  23. mindspore/_extends/builtin_operations.py +32 -4
  24. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  25. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  26. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  27. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  28. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  29. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  30. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  31. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  32. mindspore/_extends/parse/__init__.py +5 -3
  33. mindspore/_extends/parse/namespace.py +17 -2
  34. mindspore/_extends/parse/parser.py +193 -34
  35. mindspore/_extends/parse/resources.py +7 -8
  36. mindspore/_extends/parse/standard_method.py +1780 -435
  37. mindspore/_extends/parse/trope.py +3 -1
  38. mindspore/amp.py +53 -58
  39. mindspore/atlprov.dll +0 -0
  40. mindspore/boost/adasum.py +3 -2
  41. mindspore/boost/boost.py +2 -2
  42. mindspore/boost/boost_cell_wrapper.py +46 -26
  43. mindspore/boost/dim_reduce.py +6 -5
  44. mindspore/boost/grad_accumulation.py +2 -1
  45. mindspore/boost/group_loss_scale_manager.py +1 -1
  46. mindspore/c1.dll +0 -0
  47. mindspore/c1xx.dll +0 -0
  48. mindspore/c2.dll +0 -0
  49. mindspore/cfgpersist.dll +0 -0
  50. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  51. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  52. mindspore/common/__init__.py +11 -10
  53. mindspore/common/_decorator.py +2 -0
  54. mindspore/common/_register_for_adapter.py +55 -0
  55. mindspore/common/_stub_tensor.py +201 -0
  56. mindspore/common/_utils.py +57 -0
  57. mindspore/common/api.py +582 -297
  58. mindspore/common/dtype.py +66 -18
  59. mindspore/common/dump.py +2 -2
  60. mindspore/common/initializer.py +38 -1
  61. mindspore/common/jit_config.py +25 -13
  62. mindspore/common/mutable.py +53 -24
  63. mindspore/common/parameter.py +60 -37
  64. mindspore/common/seed.py +8 -24
  65. mindspore/common/sparse_tensor.py +927 -0
  66. mindspore/common/tensor.py +1627 -3900
  67. mindspore/communication/__init__.py +10 -5
  68. mindspore/communication/_comm_helper.py +78 -214
  69. mindspore/communication/_hccl_management.py +2 -1
  70. mindspore/communication/management.py +136 -47
  71. mindspore/config/op_info.config +501 -1008
  72. mindspore/context.py +291 -56
  73. mindspore/d3dcompiler_47.dll +0 -0
  74. mindspore/dataset/__init__.py +12 -8
  75. mindspore/dataset/audio/__init__.py +9 -9
  76. mindspore/dataset/audio/transforms.py +1090 -228
  77. mindspore/dataset/audio/utils.py +87 -39
  78. mindspore/dataset/audio/validators.py +223 -1
  79. mindspore/dataset/callback/ds_callback.py +17 -15
  80. mindspore/dataset/core/config.py +246 -17
  81. mindspore/dataset/core/py_util_helpers.py +4 -3
  82. mindspore/dataset/core/validator_helpers.py +10 -10
  83. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  84. mindspore/dataset/debug/debug_hook.py +65 -0
  85. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  86. mindspore/dataset/engine/__init__.py +7 -3
  87. mindspore/dataset/engine/cache_client.py +9 -9
  88. mindspore/dataset/engine/datasets.py +648 -477
  89. mindspore/dataset/engine/datasets_audio.py +165 -167
  90. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  91. mindspore/dataset/engine/datasets_text.py +492 -342
  92. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  93. mindspore/dataset/engine/datasets_vision.py +1224 -699
  94. mindspore/dataset/engine/graphdata.py +134 -69
  95. mindspore/dataset/engine/iterators.py +50 -9
  96. mindspore/dataset/engine/offload.py +52 -31
  97. mindspore/dataset/engine/samplers.py +27 -24
  98. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  99. mindspore/dataset/engine/validators.py +213 -52
  100. mindspore/dataset/text/__init__.py +10 -8
  101. mindspore/dataset/text/transforms.py +152 -57
  102. mindspore/dataset/text/utils.py +98 -49
  103. mindspore/dataset/text/validators.py +25 -0
  104. mindspore/dataset/transforms/__init__.py +4 -2
  105. mindspore/dataset/transforms/c_transforms.py +11 -13
  106. mindspore/dataset/transforms/py_transforms.py +2 -2
  107. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  108. mindspore/dataset/transforms/transforms.py +13 -15
  109. mindspore/dataset/transforms/validators.py +7 -7
  110. mindspore/dataset/utils/__init__.py +2 -1
  111. mindspore/dataset/utils/browse_dataset.py +13 -13
  112. mindspore/dataset/utils/line_reader.py +121 -0
  113. mindspore/dataset/vision/__init__.py +8 -7
  114. mindspore/dataset/vision/c_transforms.py +125 -126
  115. mindspore/dataset/vision/py_transforms.py +37 -37
  116. mindspore/dataset/vision/py_transforms_util.py +23 -20
  117. mindspore/dataset/vision/transforms.py +316 -315
  118. mindspore/dataset/vision/utils.py +313 -17
  119. mindspore/dataset/vision/validators.py +6 -6
  120. mindspore/default_config.py +0 -1
  121. mindspore/dpcmi.dll +0 -0
  122. mindspore/{compression → experimental}/__init__.py +6 -5
  123. mindspore/experimental/map_parameter.py +275 -0
  124. mindspore/include/OWNERS +0 -1
  125. mindspore/include/api/callback/callback.h +9 -13
  126. mindspore/include/api/callback/ckpt_saver.h +2 -2
  127. mindspore/include/api/callback/loss_monitor.h +2 -2
  128. mindspore/include/api/callback/lr_scheduler.h +5 -5
  129. mindspore/include/api/callback/time_monitor.h +2 -2
  130. mindspore/include/api/callback/train_accuracy.h +4 -6
  131. mindspore/include/api/cfg.h +19 -6
  132. mindspore/include/api/context.h +70 -9
  133. mindspore/include/api/delegate.h +8 -1
  134. mindspore/include/api/dual_abi_helper.h +8 -24
  135. mindspore/include/api/metrics/accuracy.h +2 -2
  136. mindspore/include/api/metrics/metrics.h +4 -3
  137. mindspore/include/api/model.h +9 -4
  138. mindspore/include/api/model_group.h +68 -0
  139. mindspore/include/api/model_parallel_runner.h +17 -17
  140. mindspore/include/api/net.h +12 -11
  141. mindspore/include/api/serialization.h +20 -4
  142. mindspore/include/api/status.h +7 -1
  143. mindspore/include/api/types.h +25 -21
  144. mindspore/include/api/visible.h +4 -0
  145. mindspore/include/c_api/model_c.h +5 -0
  146. mindspore/include/c_api/status_c.h +1 -1
  147. mindspore/include/dataset/config.h +1 -1
  148. mindspore/include/dataset/constants.h +14 -0
  149. mindspore/include/dataset/text.h +59 -0
  150. mindspore/include/dataset/vision.h +56 -117
  151. mindspore/include/dataset/vision_lite.h +102 -0
  152. mindspore/jpeg62.dll +0 -0
  153. mindspore/log.py +28 -28
  154. mindspore/mindrecord/common/exceptions.py +2 -4
  155. mindspore/mindrecord/filereader.py +19 -1
  156. mindspore/mindrecord/filewriter.py +250 -88
  157. mindspore/mindrecord/mindpage.py +13 -13
  158. mindspore/mindrecord/shardheader.py +15 -15
  159. mindspore/mindrecord/shardreader.py +9 -0
  160. mindspore/mindrecord/shardwriter.py +29 -29
  161. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  162. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  163. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  164. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  165. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  166. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  167. mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
  168. mindspore/mindspore_common.dll +0 -0
  169. mindspore/mindspore_core.dll +0 -0
  170. mindspore/mindspore_glog.dll +0 -0
  171. mindspore/mindspore_shared_lib.dll +0 -0
  172. mindspore/msobj140.dll +0 -0
  173. mindspore/mspdb140.dll +0 -0
  174. mindspore/mspdbcore.dll +0 -0
  175. mindspore/mspdbst.dll +0 -0
  176. mindspore/mspft140.dll +0 -0
  177. mindspore/msvcdis140.dll +0 -0
  178. mindspore/msvcp140_1.dll +0 -0
  179. mindspore/msvcp140_2.dll +0 -0
  180. mindspore/msvcp140_atomic_wait.dll +0 -0
  181. mindspore/msvcp140_codecvt_ids.dll +0 -0
  182. mindspore/nn/__init__.py +1 -5
  183. mindspore/nn/cell.py +297 -234
  184. mindspore/nn/dynamic_lr.py +1 -1
  185. mindspore/nn/grad/cell_grad.py +17 -42
  186. mindspore/nn/layer/__init__.py +7 -4
  187. mindspore/nn/layer/activation.py +131 -88
  188. mindspore/nn/layer/basic.py +313 -613
  189. mindspore/nn/layer/channel_shuffle.py +103 -0
  190. mindspore/nn/layer/combined.py +1 -1
  191. mindspore/nn/layer/container.py +52 -6
  192. mindspore/nn/layer/conv.py +112 -43
  193. mindspore/nn/layer/dense.py +10 -9
  194. mindspore/nn/layer/embedding.py +36 -34
  195. mindspore/nn/layer/image.py +123 -27
  196. mindspore/nn/layer/math.py +108 -107
  197. mindspore/nn/layer/normalization.py +212 -366
  198. mindspore/nn/layer/padding.py +370 -42
  199. mindspore/nn/layer/pooling.py +1443 -219
  200. mindspore/nn/layer/rnn_cells.py +11 -16
  201. mindspore/nn/layer/rnns.py +38 -39
  202. mindspore/nn/layer/thor_layer.py +24 -25
  203. mindspore/nn/layer/timedistributed.py +5 -5
  204. mindspore/nn/layer/transformer.py +701 -0
  205. mindspore/nn/learning_rate_schedule.py +8 -8
  206. mindspore/nn/loss/__init__.py +9 -6
  207. mindspore/nn/loss/loss.py +678 -142
  208. mindspore/nn/metrics.py +53 -0
  209. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  210. mindspore/nn/optim/ada_grad.py +8 -8
  211. mindspore/nn/optim/adadelta.py +2 -3
  212. mindspore/nn/optim/adafactor.py +18 -14
  213. mindspore/nn/optim/adam.py +429 -87
  214. mindspore/nn/optim/adamax.py +5 -6
  215. mindspore/nn/optim/adasum.py +10 -8
  216. mindspore/nn/optim/asgd.py +7 -7
  217. mindspore/nn/optim/ftrl.py +81 -11
  218. mindspore/nn/optim/lamb.py +7 -8
  219. mindspore/nn/optim/lars.py +4 -4
  220. mindspore/nn/optim/lazyadam.py +82 -7
  221. mindspore/nn/optim/momentum.py +8 -7
  222. mindspore/nn/optim/optimizer.py +19 -10
  223. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  224. mindspore/nn/optim/rmsprop.py +3 -3
  225. mindspore/nn/optim/rprop.py +20 -16
  226. mindspore/nn/optim/sgd.py +21 -15
  227. mindspore/nn/optim/thor.py +23 -21
  228. mindspore/nn/probability/__init__.py +0 -2
  229. mindspore/nn/probability/bijector/bijector.py +7 -6
  230. mindspore/nn/probability/bijector/invert.py +4 -2
  231. mindspore/nn/probability/bijector/softplus.py +2 -2
  232. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  233. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  234. mindspore/nn/probability/distribution/__init__.py +6 -0
  235. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  236. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  237. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  238. mindspore/nn/probability/distribution/beta.py +1 -1
  239. mindspore/nn/probability/distribution/categorical.py +9 -9
  240. mindspore/nn/probability/distribution/cauchy.py +8 -8
  241. mindspore/nn/probability/distribution/distribution.py +12 -6
  242. mindspore/nn/probability/distribution/exponential.py +5 -5
  243. mindspore/nn/probability/distribution/gamma.py +3 -3
  244. mindspore/nn/probability/distribution/geometric.py +6 -5
  245. mindspore/nn/probability/distribution/gumbel.py +5 -5
  246. mindspore/nn/probability/distribution/half_normal.py +133 -0
  247. mindspore/nn/probability/distribution/laplace.py +128 -0
  248. mindspore/nn/probability/distribution/log_normal.py +0 -1
  249. mindspore/nn/probability/distribution/logistic.py +4 -5
  250. mindspore/nn/probability/distribution/normal.py +11 -15
  251. mindspore/nn/probability/distribution/poisson.py +6 -2
  252. mindspore/nn/probability/distribution/student_t.py +150 -0
  253. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  254. mindspore/nn/probability/distribution/uniform.py +5 -5
  255. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  256. mindspore/nn/reinforcement/tensor_array.py +2 -2
  257. mindspore/nn/sparse/sparse.py +8 -1
  258. mindspore/nn/wrap/cell_wrapper.py +55 -27
  259. mindspore/nn/wrap/grad_reducer.py +20 -11
  260. mindspore/nn/wrap/loss_scale.py +47 -30
  261. mindspore/numpy/array_creations.py +33 -22
  262. mindspore/numpy/array_ops.py +46 -42
  263. mindspore/numpy/logic_ops.py +6 -27
  264. mindspore/numpy/math_ops.py +26 -19
  265. mindspore/numpy/utils.py +1 -8
  266. mindspore/numpy/utils_const.py +112 -62
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/opencv_imgproc452.dll +0 -0
  270. mindspore/ops/__init__.py +6 -3
  271. mindspore/ops/_constants.py +0 -6
  272. mindspore/ops/_grad/__init__.py +2 -1
  273. mindspore/ops/_grad/grad_array_ops.py +209 -152
  274. mindspore/ops/_grad/grad_base.py +55 -17
  275. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  276. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  277. mindspore/ops/_grad/grad_implementations.py +21 -61
  278. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  279. mindspore/ops/_grad/grad_math_ops.py +306 -161
  280. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  281. mindspore/ops/_grad/grad_other_ops.py +1 -1
  282. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  283. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  284. mindspore/ops/_grad/grad_sparse.py +15 -9
  285. mindspore/ops/_grad_experimental/__init__.py +1 -0
  286. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  287. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  288. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  289. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  290. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  291. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  292. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  293. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  294. mindspore/ops/_op_impl/__init__.py +3 -3
  295. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  296. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  297. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  298. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  299. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  300. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  302. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  303. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  304. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  305. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  307. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  314. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  315. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  317. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  318. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  320. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  321. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  322. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  326. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  327. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  328. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  329. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  330. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  331. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  332. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  335. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  336. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  337. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  338. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  339. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  340. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  341. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  342. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  343. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  344. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  345. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  346. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  347. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  349. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  350. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  351. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  352. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  353. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  354. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  355. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  356. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  357. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  358. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  359. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  360. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  361. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  362. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  363. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  365. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  366. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  367. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  368. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  369. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  370. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  371. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  372. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  373. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  374. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  375. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  376. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  377. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  378. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  379. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  380. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  381. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  382. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  383. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  384. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  385. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  386. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  387. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  388. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  389. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  390. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  391. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  392. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  393. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  394. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  395. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  396. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  397. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  398. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  399. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  400. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  401. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  402. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  403. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  404. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  405. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  407. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  408. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  409. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  410. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  411. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  412. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  413. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  414. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  415. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  416. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  417. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  418. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  419. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  420. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  421. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  422. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  423. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  424. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  425. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  426. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  427. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  428. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  429. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  430. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  431. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  432. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  433. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  434. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  435. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  436. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  437. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  441. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  442. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  443. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  444. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  445. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  446. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  447. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  448. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  449. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  450. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  451. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  452. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  453. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  454. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  455. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  456. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  457. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  458. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  459. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  460. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  461. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  462. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  463. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  464. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  465. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  466. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  467. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  468. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  469. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  470. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  471. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  472. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  473. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  474. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  475. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  476. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  477. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  478. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  479. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  480. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  481. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  482. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  483. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  484. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  485. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  486. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  487. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  488. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  489. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  490. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  491. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  492. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  493. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  494. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  495. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  496. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  497. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  498. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  499. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  500. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  501. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  502. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  503. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  504. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  505. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  506. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  507. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  508. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  509. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  510. mindspore/ops/_primitive_cache.py +3 -2
  511. mindspore/ops/_register_for_op.py +11 -0
  512. mindspore/ops/_utils/__init__.py +1 -1
  513. mindspore/ops/_utils/utils.py +20 -41
  514. mindspore/ops/_vmap/__init__.py +2 -2
  515. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  516. mindspore/ops/_vmap/vmap_base.py +24 -10
  517. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  518. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  519. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  520. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  521. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  522. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  523. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  524. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  525. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  526. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  529. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  530. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  531. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  532. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  533. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  534. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  535. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  537. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  539. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  541. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  542. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  543. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  546. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  547. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  548. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  549. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  550. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  551. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  552. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  553. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  554. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  555. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  556. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  559. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  560. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  561. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  565. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  566. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  568. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  569. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  570. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  571. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  572. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  573. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  574. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  576. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  577. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  578. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  579. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  580. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  582. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  583. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  584. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  585. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  586. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  587. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  588. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  589. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  590. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  591. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  592. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  593. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  594. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  595. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  596. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  597. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  598. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  599. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  600. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  601. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  604. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  605. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  606. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  607. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  609. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  610. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  611. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  612. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  613. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  614. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  615. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  616. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  617. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  618. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  619. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  620. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  621. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  622. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  623. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  624. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  625. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  626. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  627. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  628. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  629. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  630. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  632. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  633. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  634. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  635. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  637. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  638. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  639. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  640. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  643. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  644. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  645. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  646. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  647. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  648. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  649. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  650. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  651. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  652. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  653. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  654. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  655. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  656. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  657. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  658. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  660. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  661. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  662. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  663. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  664. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  665. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  666. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  667. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  668. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  669. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  670. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  671. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  672. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  673. mindspore/ops/bprop_mindir/__init__.py +1 -4
  674. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  675. mindspore/ops/composite/__init__.py +12 -13
  676. mindspore/ops/composite/base.py +261 -254
  677. mindspore/ops/composite/env_ops.py +41 -0
  678. mindspore/ops/composite/math_ops.py +197 -156
  679. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  680. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  681. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  682. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  683. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  684. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  685. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  686. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  687. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  688. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  689. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  690. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  691. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  692. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  693. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  694. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  695. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  696. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  697. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  698. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  699. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  700. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  701. mindspore/ops/function/__init__.py +323 -8
  702. mindspore/ops/function/array_func.py +3511 -780
  703. mindspore/ops/function/clip_func.py +329 -0
  704. mindspore/ops/function/debug_func.py +6 -6
  705. mindspore/ops/function/grad/__init__.py +5 -1
  706. mindspore/ops/function/grad/grad_func.py +736 -65
  707. mindspore/ops/function/image_func.py +270 -0
  708. mindspore/ops/function/linalg_func.py +268 -8
  709. mindspore/ops/function/math_func.py +8032 -3164
  710. mindspore/ops/function/nn_func.py +5619 -1855
  711. mindspore/ops/function/other_func.py +115 -0
  712. mindspore/ops/function/parameter_func.py +11 -10
  713. mindspore/ops/function/random_func.py +939 -77
  714. mindspore/ops/function/sparse_func.py +249 -84
  715. mindspore/ops/function/sparse_unary_func.py +2303 -0
  716. mindspore/ops/function/spectral_func.py +146 -0
  717. mindspore/ops/function/vmap_func.py +114 -0
  718. mindspore/ops/functional.py +182 -254
  719. mindspore/ops/op_info_register.py +79 -34
  720. mindspore/ops/operations/__init__.py +210 -118
  721. mindspore/ops/operations/_csr_ops.py +7 -7
  722. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  723. mindspore/ops/operations/_grad_ops.py +447 -322
  724. mindspore/ops/operations/_inner_ops.py +547 -176
  725. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  726. mindspore/ops/operations/_ms_kernel.py +29 -27
  727. mindspore/ops/operations/_ocr_ops.py +11 -11
  728. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  729. mindspore/ops/operations/_quant_ops.py +186 -101
  730. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  731. mindspore/ops/operations/_scalar_ops.py +466 -0
  732. mindspore/ops/operations/_sequence_ops.py +1047 -0
  733. mindspore/ops/operations/_tensor_array.py +10 -11
  734. mindspore/ops/operations/_thor_ops.py +4 -4
  735. mindspore/ops/operations/array_ops.py +1428 -1226
  736. mindspore/ops/operations/comm_ops.py +180 -117
  737. mindspore/ops/operations/control_ops.py +4 -2
  738. mindspore/ops/operations/custom_ops.py +185 -98
  739. mindspore/ops/operations/debug_ops.py +92 -54
  740. mindspore/ops/operations/image_ops.py +406 -211
  741. mindspore/ops/operations/inner_ops.py +42 -53
  742. mindspore/ops/operations/linalg_ops.py +32 -29
  743. mindspore/ops/operations/math_ops.py +2076 -897
  744. mindspore/ops/operations/nn_ops.py +1282 -1252
  745. mindspore/ops/operations/other_ops.py +124 -278
  746. mindspore/ops/operations/random_ops.py +345 -178
  747. mindspore/ops/operations/rl_ops.py +8 -9
  748. mindspore/ops/operations/sparse_ops.py +502 -157
  749. mindspore/ops/operations/spectral_ops.py +107 -0
  750. mindspore/ops/primitive.py +192 -15
  751. mindspore/ops/vm_impl_registry.py +23 -2
  752. mindspore/parallel/__init__.py +6 -1
  753. mindspore/parallel/_auto_parallel_context.py +199 -92
  754. mindspore/parallel/_cell_wrapper.py +4 -2
  755. mindspore/parallel/_cost_model_context.py +3 -0
  756. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  757. mindspore/parallel/_offload_context.py +185 -0
  758. mindspore/parallel/_parallel_serialization.py +167 -28
  759. mindspore/parallel/_ps_context.py +9 -5
  760. mindspore/parallel/_recovery_context.py +1 -1
  761. mindspore/parallel/_tensor.py +9 -1
  762. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  763. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  764. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  765. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  766. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  767. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  768. mindspore/parallel/_utils.py +47 -7
  769. mindspore/parallel/algo_parameter_config.py +5 -1
  770. mindspore/parallel/checkpoint_transform.py +329 -0
  771. mindspore/parallel/shard.py +229 -0
  772. mindspore/perf_msvcbuildinsights.dll +0 -0
  773. mindspore/pgodb140.dll +0 -0
  774. mindspore/pgort140.dll +0 -0
  775. mindspore/profiler/__init__.py +2 -1
  776. mindspore/profiler/common/util.py +4 -3
  777. mindspore/profiler/common/validator/validate_path.py +2 -2
  778. mindspore/profiler/envprofiling.py +249 -0
  779. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  780. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  781. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  782. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  783. mindspore/profiler/parser/framework_parser.py +42 -16
  784. mindspore/profiler/parser/hccl_parser.py +158 -158
  785. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  786. mindspore/profiler/parser/integrator.py +18 -1579
  787. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  788. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  789. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  790. mindspore/profiler/parser/optime_parser.py +17 -18
  791. mindspore/profiler/parser/profiler_info.py +108 -0
  792. mindspore/profiler/parser/step_trace_parser.py +1 -1
  793. mindspore/profiler/profiling.py +396 -194
  794. mindspore/rewrite/__init__.py +6 -2
  795. mindspore/rewrite/api/node.py +51 -110
  796. mindspore/rewrite/api/node_type.py +10 -6
  797. mindspore/rewrite/api/pattern_engine.py +51 -7
  798. mindspore/rewrite/api/scoped_value.py +64 -53
  799. mindspore/rewrite/api/symbol_tree.py +108 -61
  800. mindspore/rewrite/api/tree_node_helper.py +2 -3
  801. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  802. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  803. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  804. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  805. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  806. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  807. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  808. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  809. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  810. mindspore/rewrite/common/__init__.py +2 -0
  811. mindspore/rewrite/common/event.py +1 -1
  812. mindspore/rewrite/common/observable.py +1 -1
  813. mindspore/rewrite/common/observer.py +1 -1
  814. mindspore/rewrite/common/rewrite_elog.py +35 -0
  815. mindspore/rewrite/namer.py +2 -2
  816. mindspore/rewrite/namespace.py +14 -4
  817. mindspore/rewrite/node.py +161 -13
  818. mindspore/rewrite/parser.py +0 -1
  819. mindspore/rewrite/parser_register.py +0 -1
  820. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  821. mindspore/rewrite/parsers/assign_parser.py +267 -67
  822. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  823. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  824. mindspore/rewrite/parsers/constant_parser.py +101 -0
  825. mindspore/rewrite/parsers/container_parser.py +88 -0
  826. mindspore/rewrite/parsers/for_parser.py +28 -15
  827. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  828. mindspore/rewrite/parsers/if_parser.py +11 -28
  829. mindspore/rewrite/parsers/module_parser.py +9 -6
  830. mindspore/rewrite/parsers/return_parser.py +3 -2
  831. mindspore/rewrite/sparsify/__init__.py +0 -0
  832. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  833. mindspore/rewrite/sparsify/sparsify.py +109 -0
  834. mindspore/rewrite/sparsify/utils.py +173 -0
  835. mindspore/rewrite/symbol_tree.py +322 -109
  836. mindspore/rewrite/symbol_tree_builder.py +45 -8
  837. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  838. mindspore/rewrite/topological_manager.py +1 -2
  839. mindspore/run_check/_check_version.py +209 -112
  840. mindspore/run_check/run_check.py +2 -1
  841. mindspore/tbbmalloc.dll +0 -0
  842. mindspore/tinyxml2.dll +0 -0
  843. mindspore/train/__init__.py +6 -4
  844. mindspore/train/_utils.py +28 -5
  845. mindspore/train/amp.py +321 -50
  846. mindspore/train/callback/__init__.py +3 -1
  847. mindspore/train/callback/_backup_and_restore.py +120 -0
  848. mindspore/train/callback/_callback.py +8 -8
  849. mindspore/train/callback/_checkpoint.py +12 -9
  850. mindspore/train/callback/_early_stop.py +13 -7
  851. mindspore/train/callback/_history.py +8 -8
  852. mindspore/train/callback/_lambda_callback.py +6 -6
  853. mindspore/train/callback/_landscape.py +36 -38
  854. mindspore/train/callback/_loss_monitor.py +12 -6
  855. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  856. mindspore/train/callback/_on_request_exit.py +212 -0
  857. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  858. mindspore/train/callback/_summary_collector.py +27 -19
  859. mindspore/train/callback/_time_monitor.py +13 -7
  860. mindspore/train/checkpoint_pb2.py +68 -8
  861. mindspore/train/data_sink.py +122 -33
  862. mindspore/train/dataset_helper.py +28 -87
  863. mindspore/train/loss_scale_manager.py +4 -7
  864. mindspore/{nn → train}/metrics/__init__.py +20 -20
  865. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  866. mindspore/{nn → train}/metrics/auc.py +4 -4
  867. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  868. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  869. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  870. mindspore/{nn → train}/metrics/dice.py +6 -5
  871. mindspore/{nn → train}/metrics/error.py +7 -5
  872. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  873. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  874. mindspore/{nn → train}/metrics/loss.py +4 -3
  875. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  876. mindspore/{nn → train}/metrics/metric.py +6 -5
  877. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  878. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  879. mindspore/{nn → train}/metrics/precision.py +5 -4
  880. mindspore/{nn → train}/metrics/recall.py +5 -4
  881. mindspore/{nn → train}/metrics/roc.py +7 -6
  882. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  883. mindspore/{nn → train}/metrics/topk.py +7 -5
  884. mindspore/train/mind_ir_pb2.py +339 -32
  885. mindspore/train/model.py +113 -84
  886. mindspore/train/serialization.py +547 -167
  887. mindspore/train/summary/_summary_adapter.py +1 -1
  888. mindspore/train/summary/summary_record.py +43 -12
  889. mindspore/train/train_thor/convert_utils.py +7 -1
  890. mindspore/train/train_thor/dataset_helper.py +3 -3
  891. mindspore/train/train_thor/model_thor.py +0 -4
  892. mindspore/turbojpeg.dll +0 -0
  893. mindspore/vcmeta.dll +0 -0
  894. mindspore/vcruntime140.dll +0 -0
  895. mindspore/vcruntime140_1.dll +0 -0
  896. mindspore/version.py +1 -1
  897. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  898. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
  899. mindspore/compression/common/constant.py +0 -124
  900. mindspore/compression/export/__init__.py +0 -19
  901. mindspore/compression/export/quant_export.py +0 -514
  902. mindspore/compression/quant/qat.py +0 -636
  903. mindspore/compression/quant/quant_utils.py +0 -462
  904. mindspore/compression/quant/quantizer.py +0 -68
  905. mindspore/libatomic-1.dll +0 -0
  906. mindspore/libgcc_s_seh-1.dll +0 -0
  907. mindspore/libgfortran-4.dll +0 -0
  908. mindspore/libgomp-1.dll +0 -0
  909. mindspore/libjpeg-62.dll +0 -0
  910. mindspore/libmindspore.dll +0 -0
  911. mindspore/libmindspore_common.dll +0 -0
  912. mindspore/libmindspore_core.dll +0 -0
  913. mindspore/libmindspore_glog.dll +0 -0
  914. mindspore/libnnacl.dll +0 -0
  915. mindspore/libopencv_core452.dll +0 -0
  916. mindspore/libopencv_imgcodecs452.dll +0 -0
  917. mindspore/libopencv_imgproc452.dll +0 -0
  918. mindspore/libquadmath-0.dll +0 -0
  919. mindspore/libsqlite3.dll +0 -0
  920. mindspore/libssp-0.dll +0 -0
  921. mindspore/libstdc++-6.dll +0 -0
  922. mindspore/libtinyxml2.dll +0 -0
  923. mindspore/libturbojpeg.dll +0 -0
  924. mindspore/libwinpthread-1.dll +0 -0
  925. mindspore/nn/layer/quant.py +0 -1868
  926. mindspore/nn/layer/rnn_utils.py +0 -90
  927. mindspore/nn/probability/dpn/__init__.py +0 -22
  928. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  929. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  930. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  931. mindspore/nn/probability/infer/__init__.py +0 -22
  932. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  933. mindspore/nn/probability/infer/variational/svi.py +0 -84
  934. mindspore/nn/probability/toolbox/__init__.py +0 -22
  935. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  936. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  937. mindspore/nn/probability/transforms/__init__.py +0 -22
  938. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  939. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  940. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  941. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  942. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  943. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  944. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  945. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  946. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  947. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  948. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  949. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  950. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  951. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  952. mindspore/ops/composite/array_ops.py +0 -210
  953. mindspore/ops/composite/clip_ops.py +0 -238
  954. mindspore/ops/composite/random_ops.py +0 -426
  955. mindspore/ops/composite/vmap_ops.py +0 -38
  956. mindspore/ops/operations/sponge_ops.py +0 -3531
  957. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  958. mindspore/parallel/nn/__init__.py +0 -42
  959. mindspore/parallel/nn/loss.py +0 -22
  960. mindspore/parallel/nn/moe.py +0 -21
  961. mindspore/parallel/nn/op_parallel_config.py +0 -22
  962. mindspore/parallel/nn/transformer.py +0 -31
  963. mindspore/run_check/_check_deps_version.py +0 -84
  964. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  965. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  966. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
2
2
  #
3
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
3
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
6
  # you may not use this file except in compliance with the License.
@@ -24,27 +24,29 @@ import time
24
24
  import ast
25
25
  import inspect
26
26
  import importlib
27
+ import hashlib
27
28
  from collections import OrderedDict
28
29
  from functools import wraps
30
+ import numpy as np
29
31
  import mindspore as ms
30
32
  from mindspore import context
31
33
  from mindspore import log as logger
32
34
  from mindspore._extends.remote import kernel_build_server
33
35
  from mindspore.common.tensor import Tensor as PythonTensor
34
- from mindspore.common.tensor import CSRTensor as PythonCSRTensor
35
- from mindspore.common.tensor import COOTensor as PythonCOOTensor
36
- from mindspore.common.tensor import RowTensor as PythonRowTensor
37
- from mindspore.common.initializer import initializer
38
- from mindspore._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, RowTensor, COOTensor, \
39
- PynativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
40
- _ms_memory_recycle
41
- from mindspore.parallel._tensor import _load_tensor_by_layout
42
- from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _enable_distributed_mindrt
43
- from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _get_pipeline_stages
44
- from mindspore._checkparam import Validator
36
+ from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
37
+ from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
38
+ from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
39
+ from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
40
+ PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
41
+ _ms_memory_recycle, _bind_device_ctx
42
+ from mindspore.parallel._ps_context import _is_role_sched
43
+ from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
44
+ _get_pipeline_stages, _is_in_auto_parallel_mode
45
+ from mindspore import _checkparam as Validator
46
+ from mindspore._checkparam import is_stub_tensor
45
47
  from mindspore.common._utils import is_shape_unknown
46
48
  from mindspore.common.mutable import mutable
47
-
49
+ from mindspore.common._register_for_adapter import ms_adapter_registry
48
50
 
49
51
  # store ms_function class compiled pipeline cache
50
52
  ms_compile_cache = set()
@@ -52,7 +54,7 @@ ms_compile_cache = set()
52
54
  cells_compile_cache = {}
53
55
 
54
56
  BROADCAST_PHASE = "_broadcast_"
55
- _PYNATIVE_PARRALLEL_FUNC_NAME = "after_shard"
57
+ _PYNATIVE_PARALLEL_FUNC_NAME = "after_shard"
56
58
 
57
59
 
58
60
  def _convert_python_data(data):
@@ -65,6 +67,8 @@ def _convert_python_data(data):
65
67
  Returns:
66
68
  data, a data convert C++ to python
67
69
  """
70
+ if isinstance(data, Tensor) and data.adapter_flag:
71
+ return ms_adapter_registry.tensor(data)
68
72
  if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
69
73
  return PythonTensor(data, internal=True)
70
74
  if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
@@ -103,7 +107,8 @@ def _wrap_func(fn):
103
107
 
104
108
  def _check_all_tensor(sequence):
105
109
  for element in sequence:
106
- if not isinstance(element, Tensor) and not (isinstance(element, tuple) and _check_all_tensor(element)):
110
+ if not isinstance(element, Tensor) and not is_stub_tensor(element) and not (isinstance(element, tuple)
111
+ and _check_all_tensor(element)):
107
112
  return False
108
113
  return True
109
114
 
@@ -117,28 +122,28 @@ def _handle_func_args(func, *args, **kwargs):
117
122
  bound_arguments.apply_defaults()
118
123
  args = bound_arguments.args
119
124
  kwargs = bound_arguments.kwargs
120
- # After apply_defaults, kwargs should be empty here.
121
- if kwargs:
122
- raise ValueError(f"Failed to handle kwargs of {func.__name__}. Maybe you pass wrong arguments, "
123
- f"or there is a key in kwargs that is not used as a function argument, "
124
- f"args: {args}, kwargs: {kwargs}")
125
125
 
126
126
  positional_args = 0
127
127
  default_args = 0
128
+ has_var = False
128
129
  for value in inspect.signature(func).parameters.values():
129
130
  if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
130
- return args
131
+ has_var = True
131
132
  if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
132
133
  if value.default is inspect.Parameter.empty:
133
134
  positional_args += 1
134
135
  else:
135
136
  default_args += 1
137
+
138
+ if has_var:
139
+ return args, kwargs
140
+
136
141
  if len(args) < positional_args:
137
142
  raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument, but got {len(args)}.")
138
143
  if len(args) > positional_args + default_args:
139
144
  raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument and {default_args} "
140
145
  f"default argument, total {positional_args + default_args}, but got {len(args)}.")
141
- return args
146
+ return args, kwargs
142
147
 
143
148
 
144
149
  sys_path = list(sys.path)
@@ -164,7 +169,8 @@ def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
164
169
  for node in ast.iter_child_nodes(root):
165
170
  module_name = ""
166
171
  if isinstance(node, ast.ImportFrom):
167
- module_name = node.module
172
+ if node.module is not None:
173
+ module_name = node.module
168
174
  if node.level == 1:
169
175
  module_name = "." + module_name
170
176
  elif not isinstance(node, ast.Import):
@@ -219,31 +225,59 @@ def _restore_mutable_attr(args_list, compile_args):
219
225
  for idx, arg in enumerate(args_list):
220
226
  if hasattr(arg, "__ms_mutable__") and getattr(arg, "__ms_mutable__") and \
221
227
  not (hasattr(arg, "const_arg") and getattr(arg, "const_arg")):
222
- new_compile_args += (mutable(compile_args[idx]),)
228
+ if hasattr(arg, "__ms_dynamic_len__"):
229
+ new_compile_args += (mutable(compile_args[idx], getattr(arg, "__ms_dynamic_len__")),)
230
+ else:
231
+ new_compile_args += (mutable(compile_args[idx], False),)
223
232
  else:
224
233
  new_compile_args += (compile_args[idx],)
225
234
  return new_compile_args
226
235
 
227
236
 
228
- def _get_args_for_run(obj, args_list):
229
- """Get the actual input args for runtime."""
230
- inputs = []
231
- for i in args_list:
232
- if isinstance(i, PythonTensor):
233
- if i.has_init:
234
- i.init_data()
235
- if not i.const_arg:
236
- inputs.append(i)
237
- elif isinstance(i, (Tensor, CSRTensor, COOTensor)):
238
- inputs.append(i)
239
- elif hasattr(i, "__ms_mutable__") and getattr(i, "__ms_mutable__"):
240
- inputs.append(i)
241
- elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
242
- inputs.append(i)
243
- elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(i, tuple) and \
244
- _check_all_tensor(i):
245
- inputs.append(i)
246
- return inputs
237
+ def _get_parameter_layout():
238
+ graph_executor = GraphExecutor_.get_instance()
239
+ layout = dict()
240
+ for phase in ms_compile_cache:
241
+ layout.update(graph_executor.get_parameter_layout(phase))
242
+ return layout
243
+
244
+
245
+ def _handle_arg(obj, arg):
246
+ """Handle arg for runtime .If need handle the arg, return True"""
247
+ if isinstance(arg, PythonTensor):
248
+ if arg.has_init:
249
+ arg.init_data()
250
+ if not arg.const_arg:
251
+ return arg
252
+ elif isinstance(arg, (Tensor, CSRTensor, COOTensor)):
253
+ return arg
254
+ elif hasattr(arg, "__ms_mutable__") and getattr(arg, "__ms_mutable__"):
255
+ # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
256
+ if isinstance(arg, list) and not arg:
257
+ return None
258
+ return arg
259
+ elif context.get_context("grad_for_scalar") and isinstance(arg, (int, float)):
260
+ return arg
261
+ elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
262
+ _check_all_tensor(arg):
263
+ return arg
264
+ return None
265
+
266
+
267
+ def _get_args_for_run(obj, args, kwargs):
268
+ """Get the actual input args and kwargs for runtime."""
269
+ new_args = []
270
+ for arg in args:
271
+ new_arg = _handle_arg(obj, arg)
272
+ if new_arg is not None:
273
+ new_args.append(new_arg)
274
+
275
+ for _, value in kwargs.items():
276
+ new_value = _handle_arg(obj, value)
277
+ if new_value is not None:
278
+ new_args.append(new_value)
279
+
280
+ return new_args
247
281
 
248
282
 
249
283
  class _MindsporeFunctionExecutor:
@@ -256,7 +290,7 @@ class _MindsporeFunctionExecutor:
256
290
  Args:
257
291
  fn (Function): The root function to compile.
258
292
  input_signature (Function): User defines signature to verify input.
259
- ms_create_time(TimeStamp): The time ms_function created
293
+ ms_create_time(TimeStamp): Time the function was created
260
294
  obj (Object): If function is a method, obj is the owner of function,
261
295
  else, obj is none.
262
296
 
@@ -280,69 +314,60 @@ class _MindsporeFunctionExecutor:
280
314
  self._create_time = ms_create_time
281
315
  self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
282
316
 
283
- def _set_compile_cache_dep_files(self):
284
- # If enable compile cache, get the dependency files list
285
- enable_compile_cache = context.get_context("enable_compile_cache")
286
- if enable_compile_cache is None:
287
- enable_compile_cache = os.getenv('MS_COMPILER_CACHE_ENABLE')
288
- if enable_compile_cache is True or enable_compile_cache == "1":
289
- self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
290
-
291
- def _parallel_process_for_ms_function(self, phase):
292
- """Set parameter and optimizer states data according to sliced shape for shard"""
293
- obj = self.shard_parent_obj if self.obj is None else self.obj
294
- obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
295
- obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)
296
- replace = obj.init_parameters_data(auto_parallel_mode=True)
297
- new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
298
- self._graph_executor.updata_param_node_default_input(phase, new_param)
299
- obj.load_parameter_slice(None)
300
-
301
- if _pynative_executor.get_optimizer():
302
- params = obj.trainable_params()
303
- opt_params = _pynative_executor.get_optimizer().trainable_params()
304
- opt_states = []
305
- for opt_param in opt_params:
306
- for param in params:
307
- if opt_param.name.find(param.name) > 0:
308
- opt_states.append(opt_param)
309
- obj.parameter_layout_dict[opt_param.name] = obj.parameter_layout_dict[param.name]
310
- continue
311
-
312
- if len(opt_states) != len(params):
313
- states_tuple = (opt_states[:len(params)], opt_states[len(params):])
317
+ @_wrap_func
318
+ def __call__(self, *args, **kwargs):
319
+ args_list = args
320
+ if self.obj is not None:
321
+ args_list = args_list[1:]
322
+ phase = ""
323
+ try:
324
+ if context.get_context("mode") == context.PYNATIVE_MODE:
325
+ _pynative_executor.set_ms_function_compile_status(True, phase)
326
+ phase = self.compile(self.fn.__name__, *args_list, **kwargs)
327
+ _pynative_executor.set_ms_function_compile_status(False, phase)
314
328
  else:
315
- states_tuple = (opt_states[:len(params)],)
316
- for states in states_tuple:
317
- for param, state in zip(params, states):
318
- if param.shape != state.shape:
319
- if state.has_init:
320
- state.set_data(initializer("zeros", param.shape), True)
321
- else:
322
- layout = obj.parameter_layout_dict[param.name]
323
- new_tensor = _load_tensor_by_layout(state.data, layout)
324
- state.set_data(new_tensor, True)
325
-
326
- _pynative_executor.get_top_cell().parameter_layout_dict = obj.parameter_layout_dict
327
-
328
- def compile(self, args_list, method_name):
329
+ phase = self.compile(self.fn.__name__, *args_list, **kwargs)
330
+ except Exception as err:
331
+ _pynative_executor.clear_res()
332
+ raise err
333
+
334
+ if context.get_context("precompile_only"):
335
+ return None
336
+
337
+ new_inputs = self._generate_run_args(args_list, kwargs)
338
+ output = self._graph_executor(tuple(new_inputs), phase)
339
+ if context.get_context("mode") == context.PYNATIVE_MODE:
340
+ output = _pynative_executor.grad_ms_function(output, *new_inputs)
341
+
342
+ enable_ge = os.getenv("MS_ENABLE_GE") == "1"
343
+ if enable_ge and self.jit_config_dict is None:
344
+ raise RuntimeError("GE and jit_level=O3 should be used together, but jit_config is None.")
345
+ if self.jit_config_dict:
346
+ enable_jit_level_o3 = self.jit_config_dict.get('jit_level') == "O3"
347
+ if (enable_ge and not enable_jit_level_o3) or (not enable_ge and enable_jit_level_o3):
348
+ raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, jit_level={}".
349
+ format(os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))
350
+
351
+ return output
352
+
353
+ def compile(self, method_name, *args, **kwargs):
329
354
  """Returns pipeline for the given args."""
330
355
  # Check whether hook function registered on Cell object.
331
356
  if self.obj and hasattr(self.obj, "_hook_fn_registered"):
332
357
  if self.obj._hook_fn_registered():
333
- logger.warning(f"For 'Cell', it's not support hook function when using ms_function. If you want to "
334
- f"use hook function, please use context.set_context to set pynative mode and remove "
335
- f"`ms_function`.")
358
+ logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
359
+ f"If you want to use hook function, please use context.set_context to set "
360
+ f"pynative mode and remove 'jit' decorator.")
336
361
  # Chose dynamic shape tensors or actual input tensors as compile args.
337
- compile_args = self._generate_compile_args(args_list)
362
+ compile_args = self._generate_compile_args(args)
338
363
  # Restore the mutable attr for every arg.
339
- compile_args = _restore_mutable_attr(args_list, compile_args)
364
+ compile_args = _restore_mutable_attr(args, compile_args)
340
365
 
341
366
  generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
342
- str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
367
+ str(self.fn.__code__.co_firstlineno)
343
368
  if _pynative_executor.grad_flag():
344
369
  generate_name = generate_name + ".grad"
345
- if is_pynative_parallel():
370
+ if _is_pynative_parallel():
346
371
  generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
347
372
 
348
373
  # Add key with obj
@@ -365,7 +390,7 @@ class _MindsporeFunctionExecutor:
365
390
  self.enable_tuple_broaden = self.obj.enable_tuple_broaden
366
391
 
367
392
  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
368
- key = self._graph_executor.generate_arguments_key(compile_args, self.enable_tuple_broaden)
393
+ key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
369
394
  phase = generate_name + '.' + str(key)
370
395
  if phase in ms_compile_cache:
371
396
  return phase
@@ -376,43 +401,17 @@ class _MindsporeFunctionExecutor:
376
401
  self._graph_executor.set_jit_config(self.jit_config_dict)
377
402
 
378
403
  if self.obj is None:
379
- is_compile = self._graph_executor.compile(self.fn, compile_args, phase, True)
404
+ is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
380
405
  else:
381
406
  if isinstance(self.obj, ms.nn.Cell):
382
407
  self._graph_executor.set_weights_values(self.obj.parameters_dict())
383
- is_compile = self._graph_executor.compile(self.obj, compile_args, phase, True)
384
-
385
- # init sliced parameter and optimizer state
386
- if is_pynative_parallel() and self.fn.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
387
- self._parallel_process_for_ms_function(phase)
388
-
389
- # init the rest optimizer states
390
- if is_pynative_parallel() and _pynative_executor.get_optimizer():
391
- opt_states = _pynative_executor.get_optimizer().trainable_params()
392
- self._optimizer_state_init(opt_states)
408
+ is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase, True)
393
409
 
394
410
  if not is_compile:
395
411
  raise RuntimeError("Executor compile failed.")
396
412
  ms_compile_cache.add(phase)
397
413
  return phase
398
414
 
399
- @_wrap_func
400
- def __call__(self, *args):
401
- args_list = args
402
- if self.obj is not None:
403
- args_list = args_list[1:]
404
-
405
- phase = self.compile(args_list, self.fn.__name__)
406
- if context.get_context("precompile_only"):
407
- return None
408
- new_inputs = self._generate_run_args(args_list)
409
- output = self._graph_executor(tuple(new_inputs), phase)
410
- if context.get_context("mode") == context.PYNATIVE_MODE:
411
- _pynative_executor.set_graph_phase(phase)
412
- output = _pynative_executor.grad_ms_function(output, *new_inputs)
413
-
414
- return output
415
-
416
415
  @staticmethod
417
416
  def _optimizer_state_init(opt_states):
418
417
  """set data for all optimizer states in case it is executed in graph mode"""
@@ -423,52 +422,64 @@ class _MindsporeFunctionExecutor:
423
422
  if opt_param.has_init and (prefix in prefix_list or opt_param.name == "global_step"):
424
423
  opt_param.init_data()
425
424
 
425
+ def _set_compile_cache_dep_files(self):
426
+ # If enable compile cache, get the dependency files list
427
+ enable_compile_cache = context.get_context("enable_compile_cache")
428
+ if enable_compile_cache is None:
429
+ enable_compile_cache = os.getenv('MS_COMPILER_CACHE_ENABLE')
430
+ if enable_compile_cache is True or enable_compile_cache == "1":
431
+ self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
432
+
426
433
  def _generate_compile_args(self, args_list):
427
434
  """Chose dynamic shape tensors or actual input tensors as compile args."""
428
435
  # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
429
- compile_args = _pynative_executor.get_dynamic_input(args_list)
436
+ compile_args = args_list
430
437
  # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
431
- if isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
438
+ if self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
432
439
  compile_args = self.obj.get_inputs()
433
- for args in compile_args:
434
- Validator.check_isinstance("args set in `set_inputs()` of Cell", args, PythonTensor)
435
- Validator.check_dynamic_shape(compile_args, args_list)
440
+ if len(compile_args) != len(args_list):
441
+ raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
442
+ f"dynamic shape tensors: {len(compile_args)}.")
443
+ for i, elem in enumerate(compile_args):
444
+ if isinstance(elem, PythonTensor):
445
+ Validator.check_dynamic_shape(compile_args[i], args_list[i], i)
446
+
436
447
  # Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
437
448
  if self.input_signature is not None:
438
449
  if not isinstance(self.input_signature, (tuple, list)):
439
450
  self.input_signature = (self.input_signature,)
440
451
  self.input_signature = list(self.input_signature)
441
452
  dyn_shape = False
442
- for sig_args in self.input_signature:
443
- Validator.check_isinstance("args in `input_signature` of `ms_function`", sig_args, MetaTensor)
444
- if is_shape_unknown(sig_args.shape):
453
+ for i, elem in enumerate(self.input_signature):
454
+ if isinstance(elem, PythonTensor) and is_shape_unknown(elem.shape):
455
+ Validator.check_dynamic_shape(self.input_signature[i], args_list[i], i)
445
456
  dyn_shape = True
446
- if not dyn_shape:
447
- if not verify_inputs_signature(self.input_signature, args_list):
448
- raise ValueError("The input args is incompatible with the args in `input_signature`!")
449
- else:
457
+ if dyn_shape:
450
458
  # Checkout whether the `sens` has been added to args_list.
451
459
  if len(self.input_signature) == len(args_list) - 1:
452
- logger.warning(f"The number of actual input args `{len(args_list)}` is one more than the number "
453
- f"of input_signature args `{len(self.input_signature)}`. The last actual args may "
454
- f"be `sens` and added it to compile args.")
460
+ logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
461
+ f"of input_signature args '{len(self.input_signature)}'. The last actual args may "
462
+ f"be 'sens' and added it to compile args.")
455
463
  self.input_signature.append(args_list[-1])
456
- Validator.check_dynamic_shape(self.input_signature, args_list)
457
464
  compile_args = tuple(self.input_signature)
458
- _pynative_executor.set_dynamic_input(self.obj, *compile_args)
465
+ _pynative_executor.set_dynamic_input(self.obj)
466
+ else:
467
+ if not verify_inputs_signature(self.input_signature, args_list):
468
+ raise ValueError("The input args is incompatible with the args in `input_signature`!")
459
469
  return compile_args
460
470
 
461
- def _generate_run_args(self, args_list):
471
+ def _generate_run_args(self, args_list, kwargs):
462
472
  """
463
473
  Generate input args, which are required for running.
464
474
 
465
475
  Args:
466
476
  args_list (Tuple): Actual input args.
477
+ kwargs (Dict): Actual input kwargs.
467
478
 
468
479
  Returns:
469
480
  new_inputs, new input args, which are required for running.
470
481
  """
471
- return _get_args_for_run(self, args_list)
482
+ return _get_args_for_run(self, args_list, kwargs)
472
483
 
473
484
 
474
485
  # The attributes used to identify a given object.
@@ -496,7 +507,7 @@ def _get_ms_function_hash(hash_input):
496
507
  return _get_obj_id(hash_input)
497
508
 
498
509
 
499
- def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
510
+ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
500
511
  """
501
512
  Create a callable MindSpore graph from a Python function.
502
513
 
@@ -529,30 +540,30 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
529
540
  >>> import numpy as np
530
541
  >>> from mindspore import Tensor
531
542
  >>> from mindspore import ops
532
- >>> from mindspore import ms_function
543
+ >>> from mindspore import jit
533
544
  ...
534
545
  >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
535
546
  >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
536
547
  ...
537
- >>> # create a callable MindSpore graph by calling ms_function
548
+ >>> # create a callable MindSpore graph by calling decorator @jit
538
549
  >>> def tensor_add(x, y):
539
550
  ... z = x + y
540
551
  ... return z
541
552
  ...
542
- >>> tensor_add_graph = ms_function(fn=tensor_add)
553
+ >>> tensor_add_graph = jit(fn=tensor_add)
543
554
  >>> out = tensor_add_graph(x, y)
544
555
  ...
545
- >>> # create a callable MindSpore graph through decorator @ms_function
546
- >>> @ms_function
556
+ >>> # create a callable MindSpore graph through decorator @jit
557
+ >>> @jit
547
558
  ... def tensor_add_with_dec(x, y):
548
559
  ... z = x + y
549
560
  ... return z
550
561
  ...
551
562
  >>> out = tensor_add_with_dec(x, y)
552
563
  ...
553
- >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
554
- >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
555
- ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
564
+ >>> # create a callable MindSpore graph through decorator @jit with input_signature parameter
565
+ >>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
566
+ ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
556
567
  ... def tensor_add_with_sig(x, y):
557
568
  ... z = x + y
558
569
  ... return z
@@ -565,7 +576,7 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
565
576
  ... return ops.exp(x)
566
577
  ...
567
578
  >>> def closure_fn(x, fn):
568
- ... @ms_function(hash_args=fn)
579
+ ... @jit(hash_args=fn)
569
580
  ... def inner_fn(a):
570
581
  ... return fn(a)
571
582
  ... return inner_fn(x)
@@ -583,15 +594,18 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
583
594
 
584
595
  @wraps(func)
585
596
  def staging_specialize(*args, **kwargs):
586
- args = _handle_func_args(func, *args, **kwargs)
597
+ if os.getenv("MS_JIT") == '0':
598
+ return func(*args, **kwargs)
599
+
600
+ args, kwargs = _handle_func_args(func, *args, **kwargs)
601
+
587
602
  process_obj = None
588
603
  if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
589
604
  process_obj = args[0]
590
605
  # only the function or cell instance wrapped by shard will fall into this branch
591
- if is_pynative_parallel() and func.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
592
- process_obj = args[0]
593
- args = args[1:]
594
- out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args)
606
+ if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
607
+ process_obj = hash_args
608
+ out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args, **kwargs)
595
609
  return out
596
610
 
597
611
  return staging_specialize
@@ -601,21 +615,201 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
601
615
  return wrap_mindspore
602
616
 
603
617
 
618
+ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
619
+ """
620
+ Create a callable MindSpore graph from a Python function.
621
+
622
+ This allows the MindSpore runtime to apply optimizations based on graph.
623
+
624
+ Note:
625
+ `ms_function` will be deprecated and removed in a future version. Please use `jit` instead.
626
+ If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
627
+ will not accept `**kwargs`.
628
+
629
+ Args:
630
+ fn (Function): The Python function that will be run as a graph. Default: None.
631
+ input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
632
+ will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`.
633
+ And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should
634
+ keep the same as input_signature. Otherwise, TypeError will be raised. Default: None.
635
+ hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
636
+ like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
637
+ will trigger recompilation.
638
+ jit_config (JitConfig): Jit config for compile. Default: None.
639
+
640
+ Returns:
641
+ Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
642
+ None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
643
+ equal to the case when `fn` is not None.
644
+
645
+ Supported Platforms:
646
+ ``Ascend`` ``GPU`` ``CPU``
647
+
648
+ Examples:
649
+ >>> import numpy as np
650
+ >>> from mindspore import Tensor
651
+ >>> from mindspore import ops
652
+ >>> from mindspore import ms_function
653
+ ...
654
+ >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
655
+ >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
656
+ ...
657
+ >>> # create a callable MindSpore graph by calling ms_function
658
+ >>> def tensor_add(x, y):
659
+ ... z = x + y
660
+ ... return z
661
+ ...
662
+ >>> tensor_add_graph = ms_function(fn=tensor_add)
663
+ >>> out = tensor_add_graph(x, y)
664
+ ...
665
+ >>> # create a callable MindSpore graph through decorator @ms_function
666
+ >>> @ms_function
667
+ ... def tensor_add_with_dec(x, y):
668
+ ... z = x + y
669
+ ... return z
670
+ ...
671
+ >>> out = tensor_add_with_dec(x, y)
672
+ ...
673
+ >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
674
+ >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
675
+ ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
676
+ ... def tensor_add_with_sig(x, y):
677
+ ... z = x + y
678
+ ... return z
679
+ ...
680
+ >>> out = tensor_add_with_sig(x, y)
681
+ ...
682
+ ... # Set hash_args as fn, otherwise cache of compiled `closure_fn` will not be reused.
683
+ ... # While fn differs during calling again, recompilation will be triggered.
684
+ >>> def func(x):
685
+ ... return ops.exp(x)
686
+ ...
687
+ >>> def closure_fn(x, fn):
688
+ ... @ms_function(hash_args=fn)
689
+ ... def inner_fn(a):
690
+ ... return fn(a)
691
+ ... return inner_fn(x)
692
+ ...
693
+ >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
694
+ >>> for i in range(10):
695
+ ... closure_fn(inputs, func)
696
+ """
697
+
698
+ logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
699
+ "Please use 'mindspore.jit' instead.")
700
+ return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)
701
+
702
+
703
+ def _core(fn=None, **flags):
704
+ """
705
+ A decorator that adds a flag to the function.
706
+
707
+ By default, the function is marked as True, enabling to use this decorator to
708
+ set flag to a graph.
709
+
710
+ Args:
711
+ fn (Function): Function to add flag. Default: None.
712
+ flags (dict): The following flags can be set core, which indicates that this is a core function or
713
+ other flag. Default: None.
714
+
715
+ Returns:
716
+ Function, the function with core flag.
717
+
718
+ Supported Platforms:
719
+ ``Ascend`` ``GPU`` ``CPU``
720
+ """
721
+
722
+ # need set the attr and access on c++
723
+ def deco(fn):
724
+ fn._func_graph_flags = {
725
+ 'core': True,
726
+ **flags,
727
+ }
728
+ return fn
729
+
730
+ if fn is not None:
731
+ ret = deco(fn)
732
+ else:
733
+ ret = deco
734
+ return ret
735
+
736
+
737
+ def _add_flags(fn=None, **flags):
738
+ """
739
+ A decorator that adds a flag to the function.
740
+
741
+ Note:
742
+ Only supports bool value.
743
+
744
+ Args:
745
+ fn (Function): Function or cell to add flag. Default: None.
746
+ flags (dict): Flags use kwargs. Default: None.
747
+
748
+ Returns:
749
+ Function, the function with added flags.
750
+
751
+ Supported Platforms:
752
+ ``Ascend`` ``GPU`` ``CPU``
753
+ """
754
+
755
+ def deco(fn):
756
+ # need set the attr and access on c++
757
+ if not hasattr(fn, "_func_graph_flags"):
758
+ fn._func_graph_flags = {}
759
+
760
+ fn._func_graph_flags.update({**flags})
761
+ return fn
762
+
763
+ ret = deco
764
+ if fn is not None:
765
+ ret = deco(fn)
766
+ return ret
767
+
768
+
769
+ def _no_recursive(callable_obj):
770
+ """
771
+ Method or function decorator for ignoring recursive check.
772
+
773
+ This allows MindSpore to skip the procedure of checking function or method recursive.
774
+
775
+ Args:
776
+ callable_obj (Union(method, function)): The function or method to call.
777
+
778
+ Returns:
779
+ Function or method with no_recursive flag.
780
+
781
+ Raises:
782
+ TypeError: If ms_class is used for non-class types or nn.Cell.
783
+ AttributeError: If the private attributes or magic methods of the class decorated by ms_class is called.
784
+
785
+ Supported Platforms:
786
+ ``Ascend`` ``GPU`` ``CPU``
787
+ """
788
+ isCellSubClass = inspect.isclass(callable_obj) and issubclass(callable_obj, ms.nn.Cell)
789
+ if not isCellSubClass and not inspect.ismethod(callable_obj) and not inspect.isfunction(callable_obj):
790
+ raise TypeError(f"Decorator no_recursive is used for callable object, but got {callable_obj}.")
791
+ _add_flags(callable_obj, no_recursive=True)
792
+ return callable_obj
793
+
794
+
604
795
  def ms_class(cls):
605
796
  """
606
797
  Class decorator for user-defined classes.
607
798
 
608
799
  This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
609
800
 
801
+ Note:
802
+ `ms_class` will be deprecated and removed in a future version. Please use `jit_class` instead.
803
+
610
804
  Args:
611
805
  cls (Class): User-defined class.
612
806
 
613
807
  Returns:
614
- Class with __ms_class__ attribute.
808
+ Class.
615
809
 
616
810
  Raises:
617
811
  TypeError: If ms_class is used for non-class types or nn.Cell.
618
- AttributeError: If the private attributes or magic methods of the class decorated by ms_class is called.
812
+ AttributeError: If the private attributes or magic methods of the class decorated with ms_class is called.
619
813
 
620
814
  Supported Platforms:
621
815
  ``Ascend`` ``GPU`` ``CPU``
@@ -647,6 +841,9 @@ def ms_class(cls):
647
841
  20
648
842
  """
649
843
 
844
+ logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
845
+ "Please use 'mindspore.jit_class' instead.")
846
+
650
847
  # Check if cls is of type class.
651
848
  if not inspect.isclass(cls):
652
849
  raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
@@ -658,6 +855,81 @@ def ms_class(cls):
658
855
  return cls
659
856
 
660
857
 
858
+ def jit_class(cls):
859
+ """
860
+ Class decorator for user-defined classes.
861
+
862
+ This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
863
+
864
+ Args:
865
+ cls (Class): User-defined class.
866
+
867
+ Returns:
868
+ Class.
869
+
870
+ Raises:
871
+ TypeError: If `jit_class` is used for non-class types or nn.Cell.
872
+ AttributeError: If the private attributes or magic methods of the class decorated with `jit_class` is called.
873
+
874
+ Supported Platforms:
875
+ ``Ascend`` ``GPU`` ``CPU``
876
+
877
+ Examples:
878
+ >>> import mindspore.nn as nn
879
+ >>> from mindspore import jit_class
880
+ ...
881
+ >>> @jit_class
882
+ ... class UserDefinedNet:
883
+ ... def __init__(self):
884
+ ... self.value = 10
885
+ ...
886
+ ... def func(self, x):
887
+ ... return 2 * x
888
+ ...
889
+ >>> class Net(nn.Cell):
890
+ ... def __init__(self):
891
+ ... super(Net, self).__init__()
892
+ ... self.net = UserDefinedNet()
893
+ ...
894
+ ... def construct(self, x):
895
+ ... out = self.net.value + self.net.func(x)
896
+ ... return out
897
+ ...
898
+ >>> net = Net()
899
+ >>> out = net(5)
900
+ >>> print(out)
901
+ 20
902
+ """
903
+
904
+ # Check if cls is of type class.
905
+ if not inspect.isclass(cls):
906
+ raise TypeError(f'Decorator jit_class can only be used for class type, but got {cls}.')
907
+ # Check if cls is nn.Cell.
908
+ if issubclass(cls, ms.nn.Cell):
909
+ raise TypeError(f"Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
910
+ setattr(cls, '__ms_class__', True)
911
+ return cls
912
+
913
+
914
+ def set_adapter_config(config):
915
+ """
916
+ Register configuration information for MSAdapter.
917
+
918
+ Args:
919
+ config (dict): Configuration information.
920
+ """
921
+ if not isinstance(config, dict):
922
+ raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
923
+ for key, value in config.items():
924
+ if key == "Tensor":
925
+ setattr(value, "__adapter_tensor__", True)
926
+ ms_adapter_registry.register_tensor(value)
927
+ elif key == "convert_object_map":
928
+ ms_adapter_registry.register_convert_map(value)
929
+ else:
930
+ raise ValueError(f"Unsupported key in adapter config: {key}")
931
+
932
+
661
933
  def _function_forbid_reuse(func):
662
934
  if not inspect.isfunction(func):
663
935
  raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
@@ -665,23 +937,6 @@ def _function_forbid_reuse(func):
665
937
  return func
666
938
 
667
939
 
668
- def is_pynative_parallel():
669
- run_mode = context.get_context('mode')
670
- parallel_mode = context.get_auto_parallel_context('parallel_mode')
671
- return run_mode == context.PYNATIVE_MODE and parallel_mode in (
672
- context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL)
673
-
674
-
675
- def _get_auto_split_param_names(parameter_layout_dict):
676
- auto_split_param_names = []
677
- for key, value in parameter_layout_dict.items():
678
- for dim in value[1]:
679
- if dim != -1:
680
- auto_split_param_names.append(key)
681
- break
682
- return auto_split_param_names
683
-
684
-
685
940
  def _build_broadcast_graph(broadcast_params_dict, broadcast_phase):
686
941
  """Build broadcast graph."""
687
942
  from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
@@ -697,10 +952,24 @@ def _build_broadcast_graph(broadcast_params_dict, broadcast_phase):
697
952
  broadcast_params_dict[param_name].set_data(param)
698
953
 
699
954
 
700
- def _parameter_broadcast(obj, auto_parallel_mode):
701
- """Parameter broadcast."""
955
+ def _get_auto_split_param_names(parameter_layout_dict):
702
956
  auto_split_param_names = []
703
- if auto_parallel_mode:
957
+ for key, value in parameter_layout_dict.items():
958
+ for dim in value[1]:
959
+ if dim != -1:
960
+ auto_split_param_names.append(key)
961
+ break
962
+ return auto_split_param_names
963
+
964
+
965
+ def _parameter_broadcast(obj):
966
+ """
967
+ Parameter broadcast.
968
+ When the parallel mode is 'semi_auto_parallel' or 'auto_parallel', it will broadcast the parameters that have not
969
+ split.
970
+ """
971
+ auto_split_param_names = []
972
+ if _is_in_auto_parallel_mode():
704
973
  auto_split_param_names = _get_auto_split_param_names(obj.parameter_layout_dict)
705
974
 
706
975
  broadcast_params_dict = obj.parameters_broadcast_dict()
@@ -713,7 +982,7 @@ def _parameter_broadcast(obj, auto_parallel_mode):
713
982
  _build_broadcast_graph(broadcast_params_dict, broadcast_phase)
714
983
 
715
984
 
716
- class _PynativeExecutor:
985
+ class _PyNativeExecutor:
717
986
  """
718
987
  A pynative executor used to compile/manage/run single op.
719
988
 
@@ -731,27 +1000,59 @@ class _PynativeExecutor:
731
1000
  """
732
1001
 
733
1002
  def __init__(self):
734
- self._executor = PynativeExecutor_.get_instance()
1003
+ self._executor = PyNativeExecutor_.get_instance()
735
1004
  self._executor.set_py_exe_path(sys.executable)
736
1005
  self._executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
737
- self._optimizer = None
738
1006
  self._top_cell = None
739
1007
 
1008
+ def __call__(self):
1009
+ """
1010
+ PyNative executor run grad graph.
1011
+
1012
+ Return:
1013
+ The return object after running grad graph.
1014
+ """
1015
+ return self._executor()
1016
+
740
1017
  @staticmethod
741
- def parameter_broadcast(obj, phase, auto_parallel_mode):
1018
+ def parameter_broadcast(obj, phase):
742
1019
  """
743
1020
  Run broadcast for parameter.
744
1021
 
745
1022
  Args:
746
1023
  obj (Cell): The cell instance.
747
1024
  phase (str): The phase of cell instance.
748
- auto_parallel_mode (bool): The flag of running auto parallel.
749
1025
 
750
1026
  Return:
751
1027
  None.
752
1028
  """
753
1029
  if BROADCAST_PHASE not in phase and _get_parameter_broadcast():
754
- _parameter_broadcast(obj, auto_parallel_mode)
1030
+ _parameter_broadcast(obj)
1031
+
1032
+ def real_run_op(self, *args):
1033
+ """
1034
+ Run single op.
1035
+
1036
+ Args:
1037
+ args (tuple): Op prim and input arguments.
1038
+
1039
+ Return:
1040
+ Tensor, result of run op.
1041
+ """
1042
+ return self._executor.real_run_op(*args)
1043
+
1044
+ def run_op_async(self, prim, args):
1045
+ """
1046
+ Run single op async.
1047
+
1048
+ Args:
1049
+ prim (Primitive): Op primitive
1050
+ args (tuple): input arguments.
1051
+
1052
+ Return:
1053
+ StubNode, result of run op.
1054
+ """
1055
+ return self._executor.run_op_async(prim, args)
755
1056
 
756
1057
  def new_graph(self, obj, *args, **kwargs):
757
1058
  """
@@ -782,7 +1083,7 @@ class _PynativeExecutor:
782
1083
  """
783
1084
  self._executor.end_graph(obj, output, *args, *(kwargs.values()))
784
1085
 
785
- def check_run(self, grad, obj, grad_hash_id, *args, **kwargs):
1086
+ def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs):
786
1087
  """
787
1088
  Whether the forward graph need to construct.
788
1089
 
@@ -796,7 +1097,7 @@ class _PynativeExecutor:
796
1097
  Return:
797
1098
  bool, specifies whether the forward graph need to construct.
798
1099
  """
799
- return self._executor.check_run(grad, obj, grad_hash_id, *args, *(kwargs.values()))
1100
+ return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, *(kwargs.values()))
800
1101
 
801
1102
  def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
802
1103
  """
@@ -816,41 +1117,15 @@ class _PynativeExecutor:
816
1117
  """
817
1118
  self._executor.grad_net(grad, obj, weights, grad_position, *args, *(kwargs.values()))
818
1119
 
819
- def del_cell(self, obj):
820
- """
821
- Clean resource for cell.
822
-
823
- Args:
824
- obj (Function/Cell): The function or cell instance.
825
-
826
- Return:
827
- None.
828
- """
829
- self._executor.clear_cell(obj)
830
-
831
1120
  def clear_res(self):
832
1121
  """
833
- Clean resource for _PynativeExecutor.
1122
+ Clean resource for _PyNativeExecutor.
834
1123
 
835
1124
  Return:
836
1125
  None.
837
1126
  """
838
1127
  return self._executor.clear_res()
839
1128
 
840
- def clear_grad(self, obj, *args, **kwargs):
841
- """
842
- Clean resource after building grad graph.
843
-
844
- Args:
845
- obj (Function/Cell): The function or cell instance.
846
- args (tuple): Function or cell input arguments.
847
- kwargs (dict): keyword arguments.
848
-
849
- Return:
850
- None.
851
- """
852
- self._executor.clear_grad(obj, *args, *(kwargs.values()))
853
-
854
1129
  def sync(self):
855
1130
  """
856
1131
  SyncStream.
@@ -885,18 +1160,6 @@ class _PynativeExecutor:
885
1160
  """
886
1161
  return self._executor.grad_ms_function(output, *args)
887
1162
 
888
- def set_graph_phase(self, phase):
889
- """
890
- Set the phase of cell/function instance.
891
-
892
- Args:
893
- phase (str): The phase of cell/function instance.
894
-
895
- Return:
896
- None.
897
- """
898
- self._executor.set_graph_phase(phase)
899
-
900
1163
  def grad_flag(self):
901
1164
  """
902
1165
  The flag of building grad graph.
@@ -918,30 +1181,29 @@ class _PynativeExecutor:
918
1181
  """
919
1182
  self._executor.set_grad_flag(flag)
920
1183
 
921
- def set_dynamic_input(self, obj, *args):
1184
+ def set_ms_function_compile_status(self, status, phase):
922
1185
  """
923
- Set dynamic shape tensor of input arguments.
1186
+ Set ms_function is compiling
924
1187
 
925
1188
  Args:
926
- obj (Function/Cell): The function or cell instance.
927
- args (tuple): Function or cell dynamic input arguments.
928
-
1189
+ status(bool): ms_function compile status
1190
+ phase (str): The phase of cell/function instance.
929
1191
  Return:
930
1192
  None.
931
1193
  """
932
- self._executor.set_dynamic_input(obj, *args)
1194
+ self._executor.set_ms_function_compile_status(status, phase)
933
1195
 
934
- def get_dynamic_input(self, *actual_args):
1196
+ def set_dynamic_input(self, obj):
935
1197
  """
936
- Get dynamic shape arguments according to actual input arguments.
1198
+ Set dynamic shape tensor of input arguments.
937
1199
 
938
1200
  Args:
939
- actual_args(tuple): Actual input arguments of Function or Cell.
1201
+ obj (Function/Cell): The function or cell instance.
940
1202
 
941
1203
  Return:
942
- dynamic_shape_args(tuple): Dynamic shape arguments of Function or Cell.
1204
+ None.
943
1205
  """
944
- return self._executor.get_dynamic_input(*actual_args)
1206
+ self._executor.set_dynamic_input(obj)
945
1207
 
946
1208
  def is_first_cell(self):
947
1209
  """
@@ -965,14 +1227,6 @@ class _PynativeExecutor:
965
1227
  """
966
1228
  self._executor.set_hook_changed(cell)
967
1229
 
968
- def get_optimizer(self):
969
- """
970
- Get the optimizer.
971
-
972
- Return:
973
- The optimizer.
974
- """
975
- return self._optimizer
976
1230
 
977
1231
  def get_top_cell(self):
978
1232
  """
@@ -983,20 +1237,18 @@ class _PynativeExecutor:
983
1237
  """
984
1238
  return self._top_cell
985
1239
 
986
- def __call__(self, sens_param, obj, *args, **kwargs):
1240
+
1241
+ def constant_folding(self, *args):
987
1242
  """
988
- PyNative executor run grad graph.
1243
+ Get value by infer value.
989
1244
 
990
1245
  Args:
991
- obj (Function/Cell): The function or cell instance.
992
- args (tuple): Function or cell input arguments.
993
- kwargs (dict): keyword arguments.
1246
+ args (tuple): Op prim and input arguments.
994
1247
 
995
1248
  Return:
996
- The return object after running grad graph.
1249
+ Tensor, the value get by op infer.
997
1250
  """
998
- args = args + tuple(kwargs.values())
999
- return self._executor(sens_param, obj, args)
1251
+ return self._executor.constant_folding(*args)
1000
1252
 
1001
1253
 
1002
1254
  class _CellGraphExecutor:
@@ -1012,10 +1264,12 @@ class _CellGraphExecutor:
1012
1264
  Returns:
1013
1265
  Graph, return the result of pipeline running.
1014
1266
  """
1267
+
1015
1268
  def __init__(self):
1016
1269
  # create needed graph by lazy mode
1017
1270
  self.is_init = False
1018
1271
  self.enable_tuple_broaden = False
1272
+ self.obfuscate_config = None # used for model's dynamic obfuscation
1019
1273
  self._graph_executor = GraphExecutor_.get_instance()
1020
1274
  self._graph_executor.set_py_exe_path(sys.executable)
1021
1275
  self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
@@ -1084,17 +1338,17 @@ class _CellGraphExecutor:
1084
1338
  if "train" in phase and (enable_compile_cache is True or enable_compile_cache == "1"):
1085
1339
  self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
1086
1340
 
1087
- def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False, jit_config_dict=None):
1341
+ def compile(self, obj, *args, phase='predict', do_convert=True, jit_config_dict=None, **kwargs):
1088
1342
  """
1089
1343
  Compiles graph.
1090
1344
 
1091
1345
  Args:
1092
1346
  obj (Function/Cell): The function or cell instance need compile.
1093
- args (tuple): Function or cell input arguments.
1094
1347
  phase (str): The name of compile phase. Default: 'predict'.
1095
1348
  do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
1096
- auto_parallel_mode: When set to True, use auto parallel mode to compile graph.
1097
1349
  jit_config_dict (dict): Jit config for compile. Default: None.
1350
+ args (tuple): Args of the Cell object.
1351
+ kwargs (dict): Kwargs of the Cell object.
1098
1352
 
1099
1353
  Return:
1100
1354
  Str, the full phase of the cell.
@@ -1104,14 +1358,13 @@ class _CellGraphExecutor:
1104
1358
  if not hasattr(obj, obj.__parse_method__):
1105
1359
  raise AttributeError(
1106
1360
  'The class {} dose not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
1107
- args_list = args
1108
1361
 
1109
1362
  self.enable_tuple_broaden = False
1110
1363
  if hasattr(obj, "enable_tuple_broaden"):
1111
1364
  self.enable_tuple_broaden = obj.enable_tuple_broaden
1112
1365
 
1113
1366
  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
1114
- key = self._graph_executor.generate_arguments_key(args_list, self.enable_tuple_broaden)
1367
+ key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
1115
1368
  obj.arguments_key = str(key)
1116
1369
  phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
1117
1370
 
@@ -1121,14 +1374,14 @@ class _CellGraphExecutor:
1121
1374
 
1122
1375
  obj.check_names()
1123
1376
  _check_full_batch()
1124
- self._set_dataset_mode(args_list)
1377
+ self._set_dataset_mode(args)
1125
1378
  self._set_compile_cache_dep_files(phase)
1126
1379
 
1127
1380
  enable_ge = context.get_context("enable_ge")
1128
1381
  self._graph_executor.set_weights_values(obj.parameters_dict())
1129
1382
  if jit_config_dict:
1130
1383
  self._graph_executor.set_jit_config(jit_config_dict)
1131
- result = self._graph_executor.compile(obj, args_list, phase, self._use_vm_mode())
1384
+ result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
1132
1385
  obj.compile_cache.add(phase)
1133
1386
  if not result:
1134
1387
  raise RuntimeError("Executor compile failed.")
@@ -1137,7 +1390,15 @@ class _CellGraphExecutor:
1137
1390
  if graph is None:
1138
1391
  raise RuntimeError("Compile graph failed for phase {}.".format(phase))
1139
1392
 
1140
- self._auto_parallel_process(obj, phase, auto_parallel_mode, *args)
1393
+ auto_parallel_mode = _is_in_auto_parallel_mode()
1394
+ if not auto_parallel_mode:
1395
+ replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
1396
+ self._update_param_node_default_input(phase, replace)
1397
+ elif 'skip_auto_parallel_compile' not in obj.get_flags().keys():
1398
+ obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
1399
+ obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)
1400
+ if _get_pipeline_stages() > 1 and (not hasattr(obj, "is_first_iteration") or not obj.is_first_iteration):
1401
+ obj.remove_redundant_parameters()
1141
1402
 
1142
1403
  if not do_convert:
1143
1404
  return phase, True
@@ -1148,27 +1409,10 @@ class _CellGraphExecutor:
1148
1409
  elif "export" in phase:
1149
1410
  self._build_data_graph(obj, phase)
1150
1411
  elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
1151
- _parameter_broadcast(obj, auto_parallel_mode)
1412
+ _parameter_broadcast(obj)
1152
1413
 
1153
1414
  return phase, True
1154
1415
 
1155
- def _auto_parallel_process(self, obj, phase, auto_parallel_mode, *args):
1156
- """compile graph in auto parallel mode."""
1157
- if not auto_parallel_mode:
1158
- replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
1159
- self._update_param_node_default_input(phase, replace)
1160
- return
1161
-
1162
- obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
1163
- obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)
1164
- replace = obj.init_parameters_data(auto_parallel_mode=True)
1165
- if _get_pipeline_stages() > 1 and (not hasattr(obj, "is_first_iteration") or not obj.is_first_iteration):
1166
- obj.remove_redundant_parameters()
1167
- if not context.get_context("enable_debug_runtime") or context.get_context("enable_ge"):
1168
- obj.load_parameter_slice(None)
1169
-
1170
- self._update_param_node_default_input(phase, replace)
1171
-
1172
1416
  def _update_param_node_default_input(self, phase, replace):
1173
1417
  new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
1174
1418
  return self._graph_executor.updata_param_node_default_input(phase, new_param)
@@ -1186,8 +1430,7 @@ class _CellGraphExecutor:
1186
1430
  return self._graph_executor.get_allreduce_fusion(real_phase)
1187
1431
 
1188
1432
  def __call__(self, obj, *args, phase='predict'):
1189
- if context.get_context("precompile_only") or\
1190
- (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
1433
+ if context.get_context("precompile_only") or _is_role_sched():
1191
1434
  return None
1192
1435
  return self.run(obj, *args, phase=phase)
1193
1436
 
@@ -1215,6 +1458,8 @@ class _CellGraphExecutor:
1215
1458
  Run the specific graph.
1216
1459
 
1217
1460
  Args:
1461
+ obj (Cell): The cell object.
1462
+ args (tuple): Args of the Cell object.
1218
1463
  phase (str): The phase name. Default: 'predict'.
1219
1464
 
1220
1465
  Returns:
@@ -1229,16 +1474,33 @@ class _CellGraphExecutor:
1229
1474
  return self._exec_pip(obj, *args, phase=phase_real)
1230
1475
  raise KeyError('{} graph is not exist.'.format(phase_real))
1231
1476
 
1232
- def del_net_res(self, net_id):
1233
- self._graph_executor.del_net_res(net_id)
1477
+ def del_net_res(self, obj, net_id):
1478
+ """Clear the memory resource of a network."""
1479
+ self._graph_executor.del_net_res(obj, net_id)
1480
+
1481
+ def _get_branch_control_input(self):
1482
+ if ('obf_ratio' not in self.obfuscate_config.keys()) or (
1483
+ 'obf_random_seed' not in self.obfuscate_config.keys()):
1484
+ raise ValueError("'obf_ratio' and 'obf_random_seed' must be in obfuscate_config.")
1485
+ obf_random_seed = self.obfuscate_config.get('obf_random_seed')
1486
+ if obf_random_seed == 0:
1487
+ branch_control_input = 0
1488
+ else:
1489
+ branch_control_input = _generate_branch_control_input(obf_random_seed)
1490
+ return branch_control_input
1234
1491
 
1235
- def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False):
1492
+ def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
1236
1493
  """Get graph proto from pipeline."""
1237
1494
  if use_prefix:
1238
1495
  exec_id = exec_id + '.' + obj.arguments_key
1239
1496
  if self._graph_executor.has_compiled(exec_id) is False:
1240
1497
  return None
1241
- return self._graph_executor.get_func_graph_proto(exec_id, ir_type)
1498
+ if self.obfuscate_config is not None:
1499
+ branch_control_input = self._get_branch_control_input()
1500
+ return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, incremental,
1501
+ self.obfuscate_config['obf_ratio'],
1502
+ branch_control_input)
1503
+ return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
1242
1504
 
1243
1505
  def get_optimize_graph_proto(self, obj):
1244
1506
  """Return optimize graph binary proto."""
@@ -1261,12 +1523,6 @@ class _CellGraphExecutor:
1261
1523
  """
1262
1524
  self._graph_executor.export_graph(file_name, graph_id, encrypt_func, enc_key)
1263
1525
 
1264
- def fetch_info_for_quant_export(self, exec_id):
1265
- """Get graph proto from pipeline."""
1266
- if self._graph_executor.has_compiled(exec_id) is False:
1267
- return None
1268
- return self._graph_executor.fetch_info_for_quant_export(exec_id)
1269
-
1270
1526
 
1271
1527
  def ms_memory_recycle():
1272
1528
  """
@@ -1276,16 +1532,45 @@ def ms_memory_recycle():
1276
1532
  To recycle these cached memory, users can call this function after training of one model.
1277
1533
  """
1278
1534
  if ms_compile_cache:
1279
- _cell_graph_executor.del_net_res(ms_compile_cache)
1535
+ _cell_graph_executor.del_net_res(None, ms_compile_cache)
1280
1536
  ms_compile_cache.clear()
1281
1537
  for cell_cache in cells_compile_cache.values():
1282
1538
  if cell_cache:
1283
- _cell_graph_executor.del_net_res(cell_cache)
1539
+ _cell_graph_executor.del_net_res(None, cell_cache)
1284
1540
  cell_cache.clear()
1285
1541
  _ms_memory_recycle()
1286
1542
 
1287
1543
 
1544
+ def _generate_branch_control_input(obf_random_seed):
1545
+ """Generate append network input for dynamic obfuscation in random seed mode."""
1546
+ seed_max = 2 ** 32 - 1
1547
+ int_max = 2 ** 31 - 1
1548
+ np.random.seed(obf_random_seed % seed_max)
1549
+ # generate a string as hash function inputs
1550
+ word_repo = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghigklmnopqrstuvwxyz" + "0123456789"
1551
+ repo_len = len(word_repo)
1552
+ sha_string = ''
1553
+ string_len = 1024 * 1024
1554
+ for _ in range(string_len):
1555
+ rand_index = np.random.randint(0, repo_len)
1556
+ sha_string += word_repo[rand_index]
1557
+ # get hash result
1558
+ sha_result = hashlib.sha256(sha_string.encode('utf-8')).hexdigest() # len is 64
1559
+ branch_control_input = 1
1560
+ hex_base = 16
1561
+ for item in sha_result:
1562
+ if int(item, hex_base) > 0:
1563
+ branch_control_input *= int(item, hex_base)
1564
+ branch_control_input %= int_max
1565
+ return branch_control_input
1566
+
1567
+
1568
+ def _bind_device_context():
1569
+ """Bind device context to current thread"""
1570
+ _bind_device_ctx()
1571
+
1572
+
1288
1573
  _cell_graph_executor = _CellGraphExecutor()
1289
- _pynative_executor = _PynativeExecutor()
1574
+ _pynative_executor = _PyNativeExecutor()
1290
1575
 
1291
- __all__ = ['ms_function', 'ms_memory_recycle', 'ms_class']
1576
+ __all__ = ['ms_function', 'ms_memory_recycle', 'ms_class', 'jit', 'jit_class']