mindspore 1.10.0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged as possibly problematic.

Files changed (966)
  1. mindspore/.commit_id +1 -1
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +9 -4
  18. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +102 -0
  22. mindspore/_checkparam.py +1066 -1001
  23. mindspore/_extends/builtin_operations.py +32 -4
  24. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  25. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  26. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  27. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  28. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  29. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  30. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  31. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  32. mindspore/_extends/parse/__init__.py +5 -3
  33. mindspore/_extends/parse/namespace.py +17 -2
  34. mindspore/_extends/parse/parser.py +193 -34
  35. mindspore/_extends/parse/resources.py +7 -8
  36. mindspore/_extends/parse/standard_method.py +1780 -435
  37. mindspore/_extends/parse/trope.py +3 -1
  38. mindspore/amp.py +53 -58
  39. mindspore/atlprov.dll +0 -0
  40. mindspore/boost/adasum.py +3 -2
  41. mindspore/boost/boost.py +2 -2
  42. mindspore/boost/boost_cell_wrapper.py +46 -26
  43. mindspore/boost/dim_reduce.py +6 -5
  44. mindspore/boost/grad_accumulation.py +2 -1
  45. mindspore/boost/group_loss_scale_manager.py +1 -1
  46. mindspore/c1.dll +0 -0
  47. mindspore/c1xx.dll +0 -0
  48. mindspore/c2.dll +0 -0
  49. mindspore/cfgpersist.dll +0 -0
  50. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  51. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  52. mindspore/common/__init__.py +11 -10
  53. mindspore/common/_decorator.py +2 -0
  54. mindspore/common/_register_for_adapter.py +55 -0
  55. mindspore/common/_stub_tensor.py +201 -0
  56. mindspore/common/_utils.py +57 -0
  57. mindspore/common/api.py +582 -297
  58. mindspore/common/dtype.py +66 -18
  59. mindspore/common/dump.py +2 -2
  60. mindspore/common/initializer.py +38 -1
  61. mindspore/common/jit_config.py +25 -13
  62. mindspore/common/mutable.py +53 -24
  63. mindspore/common/parameter.py +60 -37
  64. mindspore/common/seed.py +8 -24
  65. mindspore/common/sparse_tensor.py +927 -0
  66. mindspore/common/tensor.py +1627 -3900
  67. mindspore/communication/__init__.py +10 -5
  68. mindspore/communication/_comm_helper.py +78 -214
  69. mindspore/communication/_hccl_management.py +2 -1
  70. mindspore/communication/management.py +136 -47
  71. mindspore/config/op_info.config +501 -1008
  72. mindspore/context.py +291 -56
  73. mindspore/d3dcompiler_47.dll +0 -0
  74. mindspore/dataset/__init__.py +12 -8
  75. mindspore/dataset/audio/__init__.py +9 -9
  76. mindspore/dataset/audio/transforms.py +1090 -228
  77. mindspore/dataset/audio/utils.py +87 -39
  78. mindspore/dataset/audio/validators.py +223 -1
  79. mindspore/dataset/callback/ds_callback.py +17 -15
  80. mindspore/dataset/core/config.py +246 -17
  81. mindspore/dataset/core/py_util_helpers.py +4 -3
  82. mindspore/dataset/core/validator_helpers.py +10 -10
  83. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  84. mindspore/dataset/debug/debug_hook.py +65 -0
  85. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  86. mindspore/dataset/engine/__init__.py +7 -3
  87. mindspore/dataset/engine/cache_client.py +9 -9
  88. mindspore/dataset/engine/datasets.py +648 -477
  89. mindspore/dataset/engine/datasets_audio.py +165 -167
  90. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  91. mindspore/dataset/engine/datasets_text.py +492 -342
  92. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  93. mindspore/dataset/engine/datasets_vision.py +1224 -699
  94. mindspore/dataset/engine/graphdata.py +134 -69
  95. mindspore/dataset/engine/iterators.py +50 -9
  96. mindspore/dataset/engine/offload.py +52 -31
  97. mindspore/dataset/engine/samplers.py +27 -24
  98. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  99. mindspore/dataset/engine/validators.py +213 -52
  100. mindspore/dataset/text/__init__.py +10 -8
  101. mindspore/dataset/text/transforms.py +152 -57
  102. mindspore/dataset/text/utils.py +98 -49
  103. mindspore/dataset/text/validators.py +25 -0
  104. mindspore/dataset/transforms/__init__.py +4 -2
  105. mindspore/dataset/transforms/c_transforms.py +11 -13
  106. mindspore/dataset/transforms/py_transforms.py +2 -2
  107. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  108. mindspore/dataset/transforms/transforms.py +13 -15
  109. mindspore/dataset/transforms/validators.py +7 -7
  110. mindspore/dataset/utils/__init__.py +2 -1
  111. mindspore/dataset/utils/browse_dataset.py +13 -13
  112. mindspore/dataset/utils/line_reader.py +121 -0
  113. mindspore/dataset/vision/__init__.py +8 -7
  114. mindspore/dataset/vision/c_transforms.py +125 -126
  115. mindspore/dataset/vision/py_transforms.py +37 -37
  116. mindspore/dataset/vision/py_transforms_util.py +23 -20
  117. mindspore/dataset/vision/transforms.py +316 -315
  118. mindspore/dataset/vision/utils.py +313 -17
  119. mindspore/dataset/vision/validators.py +6 -6
  120. mindspore/default_config.py +0 -1
  121. mindspore/dpcmi.dll +0 -0
  122. mindspore/{compression → experimental}/__init__.py +6 -5
  123. mindspore/experimental/map_parameter.py +275 -0
  124. mindspore/include/OWNERS +0 -1
  125. mindspore/include/api/callback/callback.h +9 -13
  126. mindspore/include/api/callback/ckpt_saver.h +2 -2
  127. mindspore/include/api/callback/loss_monitor.h +2 -2
  128. mindspore/include/api/callback/lr_scheduler.h +5 -5
  129. mindspore/include/api/callback/time_monitor.h +2 -2
  130. mindspore/include/api/callback/train_accuracy.h +4 -6
  131. mindspore/include/api/cfg.h +19 -6
  132. mindspore/include/api/context.h +70 -9
  133. mindspore/include/api/delegate.h +8 -1
  134. mindspore/include/api/dual_abi_helper.h +8 -24
  135. mindspore/include/api/metrics/accuracy.h +2 -2
  136. mindspore/include/api/metrics/metrics.h +4 -3
  137. mindspore/include/api/model.h +9 -4
  138. mindspore/include/api/model_group.h +68 -0
  139. mindspore/include/api/model_parallel_runner.h +17 -17
  140. mindspore/include/api/net.h +12 -11
  141. mindspore/include/api/serialization.h +20 -4
  142. mindspore/include/api/status.h +7 -1
  143. mindspore/include/api/types.h +25 -21
  144. mindspore/include/api/visible.h +4 -0
  145. mindspore/include/c_api/model_c.h +5 -0
  146. mindspore/include/c_api/status_c.h +1 -1
  147. mindspore/include/dataset/config.h +1 -1
  148. mindspore/include/dataset/constants.h +14 -0
  149. mindspore/include/dataset/text.h +59 -0
  150. mindspore/include/dataset/vision.h +56 -117
  151. mindspore/include/dataset/vision_lite.h +102 -0
  152. mindspore/jpeg62.dll +0 -0
  153. mindspore/log.py +28 -28
  154. mindspore/mindrecord/common/exceptions.py +2 -4
  155. mindspore/mindrecord/filereader.py +19 -1
  156. mindspore/mindrecord/filewriter.py +250 -88
  157. mindspore/mindrecord/mindpage.py +13 -13
  158. mindspore/mindrecord/shardheader.py +15 -15
  159. mindspore/mindrecord/shardreader.py +9 -0
  160. mindspore/mindrecord/shardwriter.py +29 -29
  161. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  162. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  163. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  164. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  165. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  166. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  167. mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
  168. mindspore/mindspore_common.dll +0 -0
  169. mindspore/mindspore_core.dll +0 -0
  170. mindspore/mindspore_glog.dll +0 -0
  171. mindspore/mindspore_shared_lib.dll +0 -0
  172. mindspore/msobj140.dll +0 -0
  173. mindspore/mspdb140.dll +0 -0
  174. mindspore/mspdbcore.dll +0 -0
  175. mindspore/mspdbst.dll +0 -0
  176. mindspore/mspft140.dll +0 -0
  177. mindspore/msvcdis140.dll +0 -0
  178. mindspore/msvcp140_1.dll +0 -0
  179. mindspore/msvcp140_2.dll +0 -0
  180. mindspore/msvcp140_atomic_wait.dll +0 -0
  181. mindspore/msvcp140_codecvt_ids.dll +0 -0
  182. mindspore/nn/__init__.py +1 -5
  183. mindspore/nn/cell.py +297 -234
  184. mindspore/nn/dynamic_lr.py +1 -1
  185. mindspore/nn/grad/cell_grad.py +17 -42
  186. mindspore/nn/layer/__init__.py +7 -4
  187. mindspore/nn/layer/activation.py +131 -88
  188. mindspore/nn/layer/basic.py +313 -613
  189. mindspore/nn/layer/channel_shuffle.py +103 -0
  190. mindspore/nn/layer/combined.py +1 -1
  191. mindspore/nn/layer/container.py +52 -6
  192. mindspore/nn/layer/conv.py +112 -43
  193. mindspore/nn/layer/dense.py +10 -9
  194. mindspore/nn/layer/embedding.py +36 -34
  195. mindspore/nn/layer/image.py +123 -27
  196. mindspore/nn/layer/math.py +108 -107
  197. mindspore/nn/layer/normalization.py +212 -366
  198. mindspore/nn/layer/padding.py +370 -42
  199. mindspore/nn/layer/pooling.py +1443 -219
  200. mindspore/nn/layer/rnn_cells.py +11 -16
  201. mindspore/nn/layer/rnns.py +38 -39
  202. mindspore/nn/layer/thor_layer.py +24 -25
  203. mindspore/nn/layer/timedistributed.py +5 -5
  204. mindspore/nn/layer/transformer.py +701 -0
  205. mindspore/nn/learning_rate_schedule.py +8 -8
  206. mindspore/nn/loss/__init__.py +9 -6
  207. mindspore/nn/loss/loss.py +678 -142
  208. mindspore/nn/metrics.py +53 -0
  209. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  210. mindspore/nn/optim/ada_grad.py +8 -8
  211. mindspore/nn/optim/adadelta.py +2 -3
  212. mindspore/nn/optim/adafactor.py +18 -14
  213. mindspore/nn/optim/adam.py +429 -87
  214. mindspore/nn/optim/adamax.py +5 -6
  215. mindspore/nn/optim/adasum.py +10 -8
  216. mindspore/nn/optim/asgd.py +7 -7
  217. mindspore/nn/optim/ftrl.py +81 -11
  218. mindspore/nn/optim/lamb.py +7 -8
  219. mindspore/nn/optim/lars.py +4 -4
  220. mindspore/nn/optim/lazyadam.py +82 -7
  221. mindspore/nn/optim/momentum.py +8 -7
  222. mindspore/nn/optim/optimizer.py +19 -10
  223. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  224. mindspore/nn/optim/rmsprop.py +3 -3
  225. mindspore/nn/optim/rprop.py +20 -16
  226. mindspore/nn/optim/sgd.py +21 -15
  227. mindspore/nn/optim/thor.py +23 -21
  228. mindspore/nn/probability/__init__.py +0 -2
  229. mindspore/nn/probability/bijector/bijector.py +7 -6
  230. mindspore/nn/probability/bijector/invert.py +4 -2
  231. mindspore/nn/probability/bijector/softplus.py +2 -2
  232. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  233. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  234. mindspore/nn/probability/distribution/__init__.py +6 -0
  235. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  236. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  237. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  238. mindspore/nn/probability/distribution/beta.py +1 -1
  239. mindspore/nn/probability/distribution/categorical.py +9 -9
  240. mindspore/nn/probability/distribution/cauchy.py +8 -8
  241. mindspore/nn/probability/distribution/distribution.py +12 -6
  242. mindspore/nn/probability/distribution/exponential.py +5 -5
  243. mindspore/nn/probability/distribution/gamma.py +3 -3
  244. mindspore/nn/probability/distribution/geometric.py +6 -5
  245. mindspore/nn/probability/distribution/gumbel.py +5 -5
  246. mindspore/nn/probability/distribution/half_normal.py +133 -0
  247. mindspore/nn/probability/distribution/laplace.py +128 -0
  248. mindspore/nn/probability/distribution/log_normal.py +0 -1
  249. mindspore/nn/probability/distribution/logistic.py +4 -5
  250. mindspore/nn/probability/distribution/normal.py +11 -15
  251. mindspore/nn/probability/distribution/poisson.py +6 -2
  252. mindspore/nn/probability/distribution/student_t.py +150 -0
  253. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  254. mindspore/nn/probability/distribution/uniform.py +5 -5
  255. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  256. mindspore/nn/reinforcement/tensor_array.py +2 -2
  257. mindspore/nn/sparse/sparse.py +8 -1
  258. mindspore/nn/wrap/cell_wrapper.py +55 -27
  259. mindspore/nn/wrap/grad_reducer.py +20 -11
  260. mindspore/nn/wrap/loss_scale.py +47 -30
  261. mindspore/numpy/array_creations.py +33 -22
  262. mindspore/numpy/array_ops.py +46 -42
  263. mindspore/numpy/logic_ops.py +6 -27
  264. mindspore/numpy/math_ops.py +26 -19
  265. mindspore/numpy/utils.py +1 -8
  266. mindspore/numpy/utils_const.py +112 -62
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/opencv_imgproc452.dll +0 -0
  270. mindspore/ops/__init__.py +6 -3
  271. mindspore/ops/_constants.py +0 -6
  272. mindspore/ops/_grad/__init__.py +2 -1
  273. mindspore/ops/_grad/grad_array_ops.py +209 -152
  274. mindspore/ops/_grad/grad_base.py +55 -17
  275. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  276. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  277. mindspore/ops/_grad/grad_implementations.py +21 -61
  278. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  279. mindspore/ops/_grad/grad_math_ops.py +306 -161
  280. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  281. mindspore/ops/_grad/grad_other_ops.py +1 -1
  282. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  283. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  284. mindspore/ops/_grad/grad_sparse.py +15 -9
  285. mindspore/ops/_grad_experimental/__init__.py +1 -0
  286. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  287. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  288. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  289. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  290. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  291. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  292. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  293. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  294. mindspore/ops/_op_impl/__init__.py +3 -3
  295. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  296. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  297. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  298. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  299. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  300. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  302. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  303. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  304. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  305. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  307. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  314. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  315. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  317. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  318. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  320. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  321. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  322. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  326. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  327. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  328. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  329. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  330. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  331. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  332. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  335. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  336. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  337. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  338. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  339. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  340. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  341. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  342. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  343. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  344. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  345. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  346. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  347. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  349. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  350. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  351. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  352. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  353. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  354. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  355. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  356. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  357. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  358. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  359. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  360. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  361. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  362. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  363. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  365. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  366. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  367. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  368. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  369. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  370. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  371. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  372. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  373. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  374. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  375. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  376. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  377. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  378. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  379. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  380. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  381. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  382. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  383. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  384. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  385. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  386. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  387. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  388. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  389. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  390. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  391. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  392. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  393. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  394. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  395. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  396. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  397. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  398. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  399. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  400. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  401. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  402. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  403. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  404. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  405. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  407. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  408. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  409. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  410. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  411. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  412. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  413. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  414. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  415. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  416. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  417. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  418. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  419. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  420. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  421. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  422. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  423. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  424. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  425. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  426. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  427. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  428. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  429. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  430. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  431. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  432. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  433. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  434. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  435. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  436. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  437. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  441. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  442. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  443. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  444. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  445. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  446. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  447. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  448. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  449. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  450. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  451. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  452. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  453. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  454. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  455. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  456. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  457. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  458. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  459. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  460. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  461. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  462. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  463. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  464. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  465. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  466. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  467. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  468. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  469. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  470. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  471. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  472. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  473. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  474. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  475. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  476. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  477. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  478. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  479. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  480. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  481. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  482. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  483. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  484. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  485. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  486. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  487. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  488. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  489. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  490. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  491. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  492. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  493. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  494. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  495. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  496. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  497. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  498. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  499. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  500. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  501. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  502. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  503. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  504. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  505. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  506. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  507. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  508. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  509. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  510. mindspore/ops/_primitive_cache.py +3 -2
  511. mindspore/ops/_register_for_op.py +11 -0
  512. mindspore/ops/_utils/__init__.py +1 -1
  513. mindspore/ops/_utils/utils.py +20 -41
  514. mindspore/ops/_vmap/__init__.py +2 -2
  515. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  516. mindspore/ops/_vmap/vmap_base.py +24 -10
  517. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  518. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  519. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  520. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  521. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  522. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  523. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  524. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  525. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  526. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  529. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  530. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  531. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  532. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  533. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  534. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  535. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  537. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  539. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  541. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  542. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  543. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  546. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  547. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  548. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  549. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  550. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  551. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  552. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  553. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  554. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  555. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  556. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  559. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  560. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  561. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  565. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  566. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  568. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  569. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  570. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  571. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  572. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  573. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  574. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  576. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  577. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  578. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  579. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  580. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  582. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  583. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  584. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  585. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  586. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  587. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  588. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  589. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  590. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  591. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  592. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  593. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  594. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  595. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  596. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  597. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  598. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  599. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  600. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  601. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  604. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  605. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  606. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  607. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  609. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  610. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  611. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  612. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  613. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  614. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  615. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  616. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  617. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  618. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  619. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  620. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  621. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  622. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  623. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  624. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  625. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  626. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  627. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  628. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  629. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  630. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  632. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  633. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  634. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  635. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  637. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  638. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  639. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  640. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  643. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  644. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  645. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  646. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  647. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  648. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  649. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  650. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  651. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  652. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  653. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  654. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  655. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  656. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  657. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  658. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  660. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  661. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  662. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  663. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  664. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  665. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  666. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  667. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  668. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  669. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  670. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  671. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  672. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  673. mindspore/ops/bprop_mindir/__init__.py +1 -4
  674. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  675. mindspore/ops/composite/__init__.py +12 -13
  676. mindspore/ops/composite/base.py +261 -254
  677. mindspore/ops/composite/env_ops.py +41 -0
  678. mindspore/ops/composite/math_ops.py +197 -156
  679. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  680. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  681. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  682. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  683. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  684. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  685. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  686. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  687. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  688. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  689. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  690. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  691. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  692. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  693. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  694. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  695. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  696. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  697. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  698. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  699. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  700. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  701. mindspore/ops/function/__init__.py +323 -8
  702. mindspore/ops/function/array_func.py +3511 -780
  703. mindspore/ops/function/clip_func.py +329 -0
  704. mindspore/ops/function/debug_func.py +6 -6
  705. mindspore/ops/function/grad/__init__.py +5 -1
  706. mindspore/ops/function/grad/grad_func.py +736 -65
  707. mindspore/ops/function/image_func.py +270 -0
  708. mindspore/ops/function/linalg_func.py +268 -8
  709. mindspore/ops/function/math_func.py +8032 -3164
  710. mindspore/ops/function/nn_func.py +5619 -1855
  711. mindspore/ops/function/other_func.py +115 -0
  712. mindspore/ops/function/parameter_func.py +11 -10
  713. mindspore/ops/function/random_func.py +939 -77
  714. mindspore/ops/function/sparse_func.py +249 -84
  715. mindspore/ops/function/sparse_unary_func.py +2303 -0
  716. mindspore/ops/function/spectral_func.py +146 -0
  717. mindspore/ops/function/vmap_func.py +114 -0
  718. mindspore/ops/functional.py +182 -254
  719. mindspore/ops/op_info_register.py +79 -34
  720. mindspore/ops/operations/__init__.py +210 -118
  721. mindspore/ops/operations/_csr_ops.py +7 -7
  722. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  723. mindspore/ops/operations/_grad_ops.py +447 -322
  724. mindspore/ops/operations/_inner_ops.py +547 -176
  725. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  726. mindspore/ops/operations/_ms_kernel.py +29 -27
  727. mindspore/ops/operations/_ocr_ops.py +11 -11
  728. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  729. mindspore/ops/operations/_quant_ops.py +186 -101
  730. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  731. mindspore/ops/operations/_scalar_ops.py +466 -0
  732. mindspore/ops/operations/_sequence_ops.py +1047 -0
  733. mindspore/ops/operations/_tensor_array.py +10 -11
  734. mindspore/ops/operations/_thor_ops.py +4 -4
  735. mindspore/ops/operations/array_ops.py +1428 -1226
  736. mindspore/ops/operations/comm_ops.py +180 -117
  737. mindspore/ops/operations/control_ops.py +4 -2
  738. mindspore/ops/operations/custom_ops.py +185 -98
  739. mindspore/ops/operations/debug_ops.py +92 -54
  740. mindspore/ops/operations/image_ops.py +406 -211
  741. mindspore/ops/operations/inner_ops.py +42 -53
  742. mindspore/ops/operations/linalg_ops.py +32 -29
  743. mindspore/ops/operations/math_ops.py +2076 -897
  744. mindspore/ops/operations/nn_ops.py +1282 -1252
  745. mindspore/ops/operations/other_ops.py +124 -278
  746. mindspore/ops/operations/random_ops.py +345 -178
  747. mindspore/ops/operations/rl_ops.py +8 -9
  748. mindspore/ops/operations/sparse_ops.py +502 -157
  749. mindspore/ops/operations/spectral_ops.py +107 -0
  750. mindspore/ops/primitive.py +192 -15
  751. mindspore/ops/vm_impl_registry.py +23 -2
  752. mindspore/parallel/__init__.py +6 -1
  753. mindspore/parallel/_auto_parallel_context.py +199 -92
  754. mindspore/parallel/_cell_wrapper.py +4 -2
  755. mindspore/parallel/_cost_model_context.py +3 -0
  756. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  757. mindspore/parallel/_offload_context.py +185 -0
  758. mindspore/parallel/_parallel_serialization.py +167 -28
  759. mindspore/parallel/_ps_context.py +9 -5
  760. mindspore/parallel/_recovery_context.py +1 -1
  761. mindspore/parallel/_tensor.py +9 -1
  762. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  763. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  764. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  765. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  766. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  767. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  768. mindspore/parallel/_utils.py +47 -7
  769. mindspore/parallel/algo_parameter_config.py +5 -1
  770. mindspore/parallel/checkpoint_transform.py +329 -0
  771. mindspore/parallel/shard.py +229 -0
  772. mindspore/perf_msvcbuildinsights.dll +0 -0
  773. mindspore/pgodb140.dll +0 -0
  774. mindspore/pgort140.dll +0 -0
  775. mindspore/profiler/__init__.py +2 -1
  776. mindspore/profiler/common/util.py +4 -3
  777. mindspore/profiler/common/validator/validate_path.py +2 -2
  778. mindspore/profiler/envprofiling.py +249 -0
  779. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  780. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  781. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  782. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  783. mindspore/profiler/parser/framework_parser.py +42 -16
  784. mindspore/profiler/parser/hccl_parser.py +158 -158
  785. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  786. mindspore/profiler/parser/integrator.py +18 -1579
  787. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  788. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  789. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  790. mindspore/profiler/parser/optime_parser.py +17 -18
  791. mindspore/profiler/parser/profiler_info.py +108 -0
  792. mindspore/profiler/parser/step_trace_parser.py +1 -1
  793. mindspore/profiler/profiling.py +396 -194
  794. mindspore/rewrite/__init__.py +6 -2
  795. mindspore/rewrite/api/node.py +51 -110
  796. mindspore/rewrite/api/node_type.py +10 -6
  797. mindspore/rewrite/api/pattern_engine.py +51 -7
  798. mindspore/rewrite/api/scoped_value.py +64 -53
  799. mindspore/rewrite/api/symbol_tree.py +108 -61
  800. mindspore/rewrite/api/tree_node_helper.py +2 -3
  801. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  802. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  803. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  804. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  805. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  806. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  807. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  808. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  809. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  810. mindspore/rewrite/common/__init__.py +2 -0
  811. mindspore/rewrite/common/event.py +1 -1
  812. mindspore/rewrite/common/observable.py +1 -1
  813. mindspore/rewrite/common/observer.py +1 -1
  814. mindspore/rewrite/common/rewrite_elog.py +35 -0
  815. mindspore/rewrite/namer.py +2 -2
  816. mindspore/rewrite/namespace.py +14 -4
  817. mindspore/rewrite/node.py +161 -13
  818. mindspore/rewrite/parser.py +0 -1
  819. mindspore/rewrite/parser_register.py +0 -1
  820. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  821. mindspore/rewrite/parsers/assign_parser.py +267 -67
  822. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  823. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  824. mindspore/rewrite/parsers/constant_parser.py +101 -0
  825. mindspore/rewrite/parsers/container_parser.py +88 -0
  826. mindspore/rewrite/parsers/for_parser.py +28 -15
  827. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  828. mindspore/rewrite/parsers/if_parser.py +11 -28
  829. mindspore/rewrite/parsers/module_parser.py +9 -6
  830. mindspore/rewrite/parsers/return_parser.py +3 -2
  831. mindspore/rewrite/sparsify/__init__.py +0 -0
  832. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  833. mindspore/rewrite/sparsify/sparsify.py +109 -0
  834. mindspore/rewrite/sparsify/utils.py +173 -0
  835. mindspore/rewrite/symbol_tree.py +322 -109
  836. mindspore/rewrite/symbol_tree_builder.py +45 -8
  837. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  838. mindspore/rewrite/topological_manager.py +1 -2
  839. mindspore/run_check/_check_version.py +209 -112
  840. mindspore/run_check/run_check.py +2 -1
  841. mindspore/tbbmalloc.dll +0 -0
  842. mindspore/tinyxml2.dll +0 -0
  843. mindspore/train/__init__.py +6 -4
  844. mindspore/train/_utils.py +28 -5
  845. mindspore/train/amp.py +321 -50
  846. mindspore/train/callback/__init__.py +3 -1
  847. mindspore/train/callback/_backup_and_restore.py +120 -0
  848. mindspore/train/callback/_callback.py +8 -8
  849. mindspore/train/callback/_checkpoint.py +12 -9
  850. mindspore/train/callback/_early_stop.py +13 -7
  851. mindspore/train/callback/_history.py +8 -8
  852. mindspore/train/callback/_lambda_callback.py +6 -6
  853. mindspore/train/callback/_landscape.py +36 -38
  854. mindspore/train/callback/_loss_monitor.py +12 -6
  855. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  856. mindspore/train/callback/_on_request_exit.py +212 -0
  857. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  858. mindspore/train/callback/_summary_collector.py +27 -19
  859. mindspore/train/callback/_time_monitor.py +13 -7
  860. mindspore/train/checkpoint_pb2.py +68 -8
  861. mindspore/train/data_sink.py +122 -33
  862. mindspore/train/dataset_helper.py +28 -87
  863. mindspore/train/loss_scale_manager.py +4 -7
  864. mindspore/{nn → train}/metrics/__init__.py +20 -20
  865. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  866. mindspore/{nn → train}/metrics/auc.py +4 -4
  867. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  868. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  869. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  870. mindspore/{nn → train}/metrics/dice.py +6 -5
  871. mindspore/{nn → train}/metrics/error.py +7 -5
  872. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  873. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  874. mindspore/{nn → train}/metrics/loss.py +4 -3
  875. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  876. mindspore/{nn → train}/metrics/metric.py +6 -5
  877. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  878. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  879. mindspore/{nn → train}/metrics/precision.py +5 -4
  880. mindspore/{nn → train}/metrics/recall.py +5 -4
  881. mindspore/{nn → train}/metrics/roc.py +7 -6
  882. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  883. mindspore/{nn → train}/metrics/topk.py +7 -5
  884. mindspore/train/mind_ir_pb2.py +339 -32
  885. mindspore/train/model.py +113 -84
  886. mindspore/train/serialization.py +547 -167
  887. mindspore/train/summary/_summary_adapter.py +1 -1
  888. mindspore/train/summary/summary_record.py +43 -12
  889. mindspore/train/train_thor/convert_utils.py +7 -1
  890. mindspore/train/train_thor/dataset_helper.py +3 -3
  891. mindspore/train/train_thor/model_thor.py +0 -4
  892. mindspore/turbojpeg.dll +0 -0
  893. mindspore/vcmeta.dll +0 -0
  894. mindspore/vcruntime140.dll +0 -0
  895. mindspore/vcruntime140_1.dll +0 -0
  896. mindspore/version.py +1 -1
  897. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  898. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
  899. mindspore/compression/common/constant.py +0 -124
  900. mindspore/compression/export/__init__.py +0 -19
  901. mindspore/compression/export/quant_export.py +0 -514
  902. mindspore/compression/quant/qat.py +0 -636
  903. mindspore/compression/quant/quant_utils.py +0 -462
  904. mindspore/compression/quant/quantizer.py +0 -68
  905. mindspore/libatomic-1.dll +0 -0
  906. mindspore/libgcc_s_seh-1.dll +0 -0
  907. mindspore/libgfortran-4.dll +0 -0
  908. mindspore/libgomp-1.dll +0 -0
  909. mindspore/libjpeg-62.dll +0 -0
  910. mindspore/libmindspore.dll +0 -0
  911. mindspore/libmindspore_common.dll +0 -0
  912. mindspore/libmindspore_core.dll +0 -0
  913. mindspore/libmindspore_glog.dll +0 -0
  914. mindspore/libnnacl.dll +0 -0
  915. mindspore/libopencv_core452.dll +0 -0
  916. mindspore/libopencv_imgcodecs452.dll +0 -0
  917. mindspore/libopencv_imgproc452.dll +0 -0
  918. mindspore/libquadmath-0.dll +0 -0
  919. mindspore/libsqlite3.dll +0 -0
  920. mindspore/libssp-0.dll +0 -0
  921. mindspore/libstdc++-6.dll +0 -0
  922. mindspore/libtinyxml2.dll +0 -0
  923. mindspore/libturbojpeg.dll +0 -0
  924. mindspore/libwinpthread-1.dll +0 -0
  925. mindspore/nn/layer/quant.py +0 -1868
  926. mindspore/nn/layer/rnn_utils.py +0 -90
  927. mindspore/nn/probability/dpn/__init__.py +0 -22
  928. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  929. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  930. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  931. mindspore/nn/probability/infer/__init__.py +0 -22
  932. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  933. mindspore/nn/probability/infer/variational/svi.py +0 -84
  934. mindspore/nn/probability/toolbox/__init__.py +0 -22
  935. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  936. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  937. mindspore/nn/probability/transforms/__init__.py +0 -22
  938. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  939. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  940. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  941. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  942. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  943. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  944. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  945. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  946. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  947. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  948. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  949. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  950. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  951. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  952. mindspore/ops/composite/array_ops.py +0 -210
  953. mindspore/ops/composite/clip_ops.py +0 -238
  954. mindspore/ops/composite/random_ops.py +0 -426
  955. mindspore/ops/composite/vmap_ops.py +0 -38
  956. mindspore/ops/operations/sponge_ops.py +0 -3531
  957. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  958. mindspore/parallel/nn/__init__.py +0 -42
  959. mindspore/parallel/nn/loss.py +0 -22
  960. mindspore/parallel/nn/moe.py +0 -21
  961. mindspore/parallel/nn/op_parallel_config.py +0 -22
  962. mindspore/parallel/nn/transformer.py +0 -31
  963. mindspore/run_check/_check_deps_version.py +0 -84
  964. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  965. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  966. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
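Several entries above are renames rather than edits (the `{old → new}` paths), notably `mindspore/{nn → train}/metrics/*` (entries 864-883) and `mindspore/{nn/transformer → parallel/_transformer}/*` (entries 762-767), plus the wholesale removal of `mindspore/compression/*`. Below is a minimal, version-tolerant import sketch for downstream code that must run against both wheels. It assumes the public import paths follow the file moves shown here; entry 208 adds a thin `mindspore/nn/metrics.py`, which suggests the old path may survive as a deprecation shim, but this is worth verifying against the 2.0.0rc1 release notes.

    # Version-tolerant import: the metrics package moved from mindspore.nn
    # to mindspore.train in this diff. Prefer the new path, fall back to the
    # 1.10.0 layout if it is absent.
    try:
        from mindspore.train.metrics import Accuracy   # 2.0.0rc1 layout
    except ImportError:
        from mindspore.nn.metrics import Accuracy      # 1.10.0 layout

    metric = Accuracy('classification')
    metric.clear()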
mindspore/nn/cell.py CHANGED
@@ -29,18 +29,42 @@ from mindspore import log as logger
  from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
  from mindspore.common.hook_handle import HookHandle
  from mindspore.context import ParallelMode
- from mindspore.ops.composite import Shard
  from mindspore import context
  from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
- from mindspore._checkparam import Validator
+ from mindspore import _checkparam as Validator
  from mindspore.common import dtype as mstype
  from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
+ from mindspore.common.api import _generate_branch_control_input
  from mindspore.common.parameter import Parameter, ParameterTuple
  from mindspore.common.tensor import Tensor
  from mindspore.ops.operations import Cast
  from mindspore.ops.primitive import Primitive
  from mindspore.ops.operations import _inner_ops as inner
- from mindspore.parallel._tensor import _load_tensor_by_layout
+ from mindspore.parallel.shard import Shard
+ from mindspore._check_jit_forbidden_api import jit_forbidden_register
+
+
+ def _check_args(args):
+     """Check the types of the input args."""
+     index = 1
+     for item in args:
+         if isinstance(item, Tensor) and item.has_init:
+             item.init_data()
+         elif isinstance(item, numpy.ndarray):
+             suffix = "th"
+             if index == 1:
+                 suffix = "st"
+             elif index == 2:
+                 suffix = "nd"
+             elif index == 3:
+                 suffix = "rd"
+
+             input_index = str(index) + suffix
+             raise TypeError(f"For 'Cell', inputs should not be numpy arrays. Only bool, int, float, None, "
+                             f"Tensor, Parameter, mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint"
+                             f"), tuple or list containing only these types, and dict whose values are these "
+                             f"types are supported, but the {input_index} arg type is {type(item)}.")
+         index += 1


  class Cell(Cell_):
@@ -54,11 +78,14 @@ class Cell(Cell_):
      PYNATIVE_MODE (dynamic graph mode).

      Args:
-         auto_prefix (bool): Whether to automatically generate NameSpace for Cell and its subcells. It will affect the
-             name of the parameter in the network. If set to True, the network parameter
-             name will be prefixed, otherwise it will not. Default: True.
-         flags (dict): Network configuration information, currently it is used for the binding of network and dataset.
-             Users can also customize network attributes by this parameter. Default: None.
+         auto_prefix (bool, optional): Whether to automatically generate NameSpace for the Cell and its child cells.
+             It also affects the names of parameters in the `Cell`. If set to True, the parameter names will be
+             automatically prefixed, otherwise not. In general, the backbone network should be set to True,
+             otherwise duplicate-name problems will appear. Cells that train the backbone network, such as
+             optimizers and :class:`mindspore.nn.TrainOneStepCell`, should be set to False, otherwise the
+             parameter names in the backbone will be changed by mistake. Default: True.
+         flags (dict, optional): Network configuration information, currently used for binding the network
+             and the dataset. Users can also customize network attributes via this parameter. Default: None.

      Supported Platforms:
          ``Ascend`` ``GPU`` ``CPU``
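The `auto_prefix` behavior described above is easiest to see on a toy network. A minimal sketch (the class and attribute names are illustrative, not from this diff):

```python
import mindspore.nn as nn

class Backbone(nn.Cell):
    """Toy backbone; auto_prefix=True (the default) scopes parameter names."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(2, 2)

    def construct(self, x):
        return self.dense(x)

net = Backbone()
# Parameter names are prefixed with the attribute path, e.g. 'dense.weight'.
print([p.name for p in net.get_parameters()])
```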
@@ -84,12 +111,11 @@ class Cell(Cell_):
      [Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
      """

-     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
-                    '_construct_inputs_num', '_create_time', '_func_graph_flags', '_parallel_inputs_run',
-                    '_parameter_layout_dict', '_params_list', '_tensor_list', '_phase', '_auto_parallel_mode',
+     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
+                    '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_tensor_list', '_phase',
                     '_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
                     '_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
-                    '_attr_synced', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type']
+                    '_attr_synced', 'pynative', 'requires_grad', 'cell_type']

      def __init__(self, auto_prefix=True, flags=None):
          Cell_.__init__(self, self._cell_tag)
@@ -123,10 +149,6 @@ class Cell(Cell_):
          if os.getenv('GC_COLLECT_IN_CELL') == '1':
              gc.collect()

-         self._construct_inputs_num = 0
-         self._construct_inputs_names = []
-         self._auto_parallel_mode = False
-         self._parallel_inputs_run = None
          if flags:
              self.add_flags(**flags)
          self._bprop_debug = False
@@ -136,8 +158,8 @@ class Cell(Cell_):
          self._enable_forward_hook = False
          self._enable_backward_hook = False
          self._cell_backward_hook = None
+         self._is_recursion_hook = False
          self.cell_type = None
-         self._auto_parallel_compile_and_run = False
          self.cast = Cast()
          self._has_config_recompute = False
          self._user_parameters = []
@@ -145,6 +167,7 @@ class Cell(Cell_):
          self.saved_dynamic_shape = None
          self._jit_config_dict = dict()
          self.grad_ops_label = False
+         self.to_float_fp16 = False

      def __getstate__(self):
          base = Cell_.__getstate__(self)
@@ -156,6 +179,9 @@ class Cell(Cell_):
          self.__dict__ = dict_
          self._attr_synced = False

+     def __bool__(self):
+         return True
+
      @property
      def _cell_tag(self):
          # `<class 'xxxxxxx'>` to `xxxxxxx`
@@ -310,8 +336,6 @@ class Cell(Cell_):
          if '_params' in self.__dict__:
              params = self.__dict__['_params']
              if name in params:
-                 if context._get_mode() == context.PYNATIVE_MODE:
-                     return self.cast_param(params[name])
                  return params[name]
          if '_cells' in self.__dict__:
              cells = self.__dict__['_cells']
@@ -320,27 +344,23 @@ class Cell(Cell_):
          if '_tensor_list' in self.__dict__:
              tensor_list = self.__dict__['_tensor_list']
              if name in tensor_list:
-                 return self.cast_param(tensor_list[name])
+                 return tensor_list[name]
          if '_params_list' in self.__dict__:
              params_list = self.__dict__['_params_list']
              if name in params_list:
-                 para_list = params_list[name]
-                 cast_list = list()
-                 for para in para_list:
-                     cast_list.append(self.cast_param(para))
-                 para_list = ParameterTuple(cast_list)
-                 return para_list
+                 return ParameterTuple(params_list[name])
          raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))

      def __del__(self):
-         if context.get_context is not None and context._get_mode() == context.PYNATIVE_MODE:
-             _pynative_executor.del_cell(self)
-
          # While deepcopying a cell instance, the copied instance can't be added to cells_compile_cache;
          # use pop(id(self), None) here to avoid a KeyError.
          cells_compile_cache.pop(id(self), None)
-         if self.compile_cache:
-             _cell_graph_executor.del_net_res(self.compile_cache)
+         try:
+             if self.compile_cache:
+                 _cell_graph_executor.del_net_res(self, self.compile_cache)
+         except AttributeError as e:
+             raise AttributeError(f"The '{type(self).__name__}' object does not inherit from 'Cell'. "
+                                  f"Please call 'super().__init__()'.") from e

      def __delattr__(self, name):
          if name in self._params:
@@ -391,7 +411,7 @@ class Cell(Cell_):
      def _do_parameter_broadcast(self):
          if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
              if not self.parameter_broadcast_done:
-                 _pynative_executor.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
+                 _pynative_executor.parameter_broadcast(self, self.phase)
                  self.parameter_broadcast_done = True

      def run_construct(self, cast_inputs, kwargs):
@@ -427,39 +447,51 @@ class Cell(Cell_):
          output = self._run_forward_hook(cast_inputs, output)
          return output

-     def _check_construct_args(self, *inputs, **kwargs):
+     def _check_construct_args(self, *args):
          """Check the args needed by the function construct."""
-         if kwargs:
-             raise ValueError(f"For 'Cell', expect no kwargs here; maybe you passed wrong arguments, "
-                              f"or there is a key in kwargs that is not used as a function argument. "
-                              f"args: {inputs}, kwargs: {kwargs}")
          positional_args = 0
          default_args = 0
+         has_var = False
          for value in inspect.signature(self.construct).parameters.values():
              if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
-                 return
+                 has_var = True
              if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
                  if value.default is inspect.Parameter.empty:
                      positional_args += 1
                  else:
                      default_args += 1

-         if len(inputs) < positional_args:
+         if has_var:
+             return
+
+         if len(args) < positional_args:
              raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument(s), "
-                             f"but got {len(inputs)}. When using set_inputs, please make sure that all networks "
+                             f"but got {len(args)}. When using set_inputs, please make sure that all networks "
                              f"and loss functions are configured with set_inputs.")

-         if len(inputs) > positional_args + default_args:
+         if len(args) > positional_args + default_args:
+             construct_inputs_names = self.construct.__code__.co_varnames
+             if 'self' not in construct_inputs_names:
+                 raise TypeError(f"For 'Cell', the method 'construct' must have parameter 'self'.")
+
              raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument(s) and "
                              f"{default_args} default argument(s), {positional_args + default_args} in total, "
-                             f"but got {len(inputs)}.")
+                             f"but got {len(args)}.")

      def _hook_fn_registered(self):
-         if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
-             return True
-         for cell in self.cells():
-             if cell._hook_fn_registered():
+         '''Check whether any hook function is registered (for graph mode).'''
+         # Check that super().__init__() was called, for graph mode.
+         try:
+             if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
                  return True
+         except AttributeError as e:
+             raise AttributeError(f"The '{type(self).__name__}' object does not inherit from 'Cell'. "
+                                  f"Please call 'super().__init__()'.") from e
+         if not self._is_recursion_hook:
+             self._is_recursion_hook = True
+             for cell in self.cells():
+                 if cell._hook_fn_registered():
+                     return True
          return False

      def _get_prims_recursively(self):
@@ -494,7 +526,7 @@ class Cell(Cell_):
          for prim in all_prims:
              prim.add_prim_attr("strategy_gen_mode", "data_parallel")

-     def shard(self, in_strategy, out_strategy, parameter_plan=None, device="Ascend", level=0):
+     def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
          """
          Defines the input and output layouts of this cell; the parallel strategies of the remaining ops will be
          generated by sharding propagation. In PyNative mode, use this method
@@ -508,11 +540,13 @@ class Cell(Cell_):
          Note:
              Only effective in PYNATIVE_MODE and in ParallelMode.AUTO_PARALLEL with
              search_mode in auto_parallel_context set as sharding_propagation.
+             If the inputs contain a Parameter, its strategy should be set in `in_strategy`.

          Args:
              in_strategy (tuple): Define the layout of inputs; each element of the tuple should be a tuple or None.
                  Tuple defines the layout of the corresponding input and None represents a data parallel strategy.
-             out_strategy (tuple): Define the layout of outputs similar with in_strategy.
+             out_strategy (Union[None, tuple]): Define the layout of outputs, similar to in_strategy.
+                 It is not in use right now. Default: None.
              parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
                  defines the layout of the parameter like "param_name: layout".
                  The key is a parameter name of type 'str'.
@@ -552,7 +586,11 @@ class Cell(Cell_):
          ...         x = self.block2(x)
          ...         return x
          """
-         # Transfer parameter_plan from dict to tuple
+         if context.get_context("mode") != context.PYNATIVE_MODE or \
+                 context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel"]:
+             raise AssertionError("Cell.shard only supports auto parallel under PyNative mode. "
+                                  "Please check whether Cell.shard is called correctly in the script.")
+
          shard_fn = Shard()
          fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
          object.__setattr__(self, "_shard_fn", fn)
@@ -568,6 +606,8 @@ class Cell(Cell_):
          Returns:
              Tuple, the inputs after data type cast.
          """
+         msg = "'auto_cast_inputs' is deprecated from version 2.0 and will be removed in a future version."
+         logger.warning(msg)
          cast_inputs = inputs
          mixed_type = self.get_mixed_precision_type()
          if mixed_type == MixedPrecisionType.FP16:
@@ -577,32 +617,10 @@ class Cell(Cell_):

          return cast_inputs

-     def _check_args(self, args):
-         """Check the input args's type"""
-         index = 1
-         for item in args:
-             if isinstance(item, Tensor) and item.has_init:
-                 item.init_data()
-             elif isinstance(item, numpy.ndarray):
-                 suffix = "th"
-                 if index == 1:
-                     suffix = "st"
-                 elif index == 2:
-                     suffix = "nd"
-                 elif index == 3:
-                     suffix = "rd"
-
-                 input_index = str(index) + suffix
-                 raise TypeError(f"For 'Cell', inputs should not be numpy array. Only support bool, int, float, None, "
-                                 f"Tensor, Parameter, mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint"
-                                 f"), and tuple or list containing only these types, and dict whose values are these "
-                                 f"types, but the {input_index} arg type is {type(item)}.")
-             index += 1
-
      def __call__(self, *args, **kwargs):
          if self.__class__.construct is Cell.construct:
-             logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
-                            f"it will call the super class(Cell) 'construct'.")
+             raise AttributeError("For 'Cell', the method 'construct' is not defined.")
+
          if kwargs:
              bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
              bound_arguments.apply_defaults()
@@ -610,34 +628,33 @@ class Cell(Cell_):
              kwargs = bound_arguments.kwargs

          # Run in Graph mode.
-         if context._get_mode() == context.GRAPH_MODE:
-             self._check_construct_args(*args, **kwargs)
+         if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
+             self._check_construct_args(*args)
              if self._hook_fn_registered():
                  logger.warning(f"For 'Cell', hook functions are not supported in graph mode. If you want to use a "
                                 f"hook function, please use context.set_context to set pynative mode.")
-             out = self.compile_and_run(*args)
+             out = self.compile_and_run(*args, **kwargs)
              return out

          # Run in PyNative mode.
          if _pynative_executor.is_first_cell():
-             _pynative_executor.set_lazy_build(True)
              _pynative_executor._optimizer = getattr(self, "optimizer", None)
              _pynative_executor._top_cell = self
-             # There are many Casts in parameter_broadcast. Enable lazy_build to build faster.
+             # There are many Casts in parameter_broadcast; this helps build faster.
              self._do_parameter_broadcast()

-         self._check_args(args)
+         _check_args(args)
+         self._check_cell_flags_in_pynative()

          if self.requires_grad:
              _pynative_executor.set_grad_flag(True)

          if self._dynamic_shape_inputs is not None:
-             self._check_compile_dynamic_shape(*args)
+             self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)

          try:
              _pynative_executor.new_graph(self, *args, **kwargs)
-             cast_inputs = self.auto_cast_inputs(args)
-             output = self._run_construct(cast_inputs, kwargs)
+             output = self._run_construct(args, kwargs)
              _pynative_executor.end_graph(self, output, *args, **kwargs)
          except Exception as err:
              _pynative_executor.clear_res()
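The rewritten `__call__` above routes through `compile_and_run` only when `MS_JIT` is not '0' and the context is GRAPH_MODE; otherwise it executes eagerly through the PyNative executor. A minimal sketch of both paths (the network and shapes are illustrative):

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn

net = nn.Dense(3, 2)
x = ms.Tensor(np.ones((4, 3), np.float32))

ms.set_context(mode=ms.GRAPH_MODE)      # __call__ compiles the cell and runs the graph
print(net(x).shape)                     # (4, 2)

ms.set_context(mode=ms.PYNATIVE_MODE)   # __call__ executes op-by-op via the PyNative executor
print(net(x).shape)                     # (4, 2)
```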
@@ -647,6 +664,12 @@ class Cell(Cell_):
              output = output.data
          return output

+     def _check_cell_flags_in_pynative(self):
+         """Check the flags added to the cell in PyNative mode."""
+         if hasattr(self, "_func_graph_flags") and self._func_graph_flags.get("output_no_recompute"):
+             raise TypeError("Recompute is not supported in PyNative mode currently; you can use "
+                             "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
+
      def _add_attr(self, name, value):
          if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
              super(Cell, self)._add_attr(name, value)
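`_check_cell_flags_in_pynative` above rejects recompute flags outside graph mode. A hedged sketch of the graph-mode usage its error message points to:

```python
import mindspore as ms
import mindspore.nn as nn

ms.set_context(mode=ms.GRAPH_MODE)  # recompute is only supported in graph mode
block = nn.Dense(3, 2)
block.recompute()  # recompute this block's forward pass during backprop to save memory
```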
@@ -829,84 +852,19 @@ class Cell(Cell_):
          """
          Replace parameters with sliced tensors by parallel strategies.

-         Please refer to the usage in source code of `mindspore.common._CellGraphExecutor.compile`.
-
-         Args:
-             params (dict): The parameters dictionary used for initializing the data graph.
-         """
-         if params is None:
-             params = self.parameters_dict()
-         if isinstance(params, OrderedDict):
-             for key in params:
-                 tensor = params[key].data
-                 if key not in self.parameter_layout_dict:
-                     logger.info("The layout dict does not contain the key %s.", key)
-                     continue
-                 if params[key].sliced:
-                     logger.debug("The param %s is already sliced.", key)
-                     continue
-                 layout = self.parameter_layout_dict[key]
-                 new_tensor = _load_tensor_by_layout(tensor, layout)
-                 params[key].set_data(new_tensor, True)
-         else:
-             raise TypeError("For 'load_parameter_slice', the argument 'params' must be OrderedDict type, "
-                             "but got {}.".format(type(params)))
-
-     def _load_inputs(self, *inputs):
+         Note:
+             This interface is deprecated.
          """
-         Slice inputs tensors by parallel strategies.
-
-         Args:
-             inputs (Function or Cell): inputs of construct method.
-         """
-         parallel_inputs_run = []
-         # judge if *args exists in input
-         if self.argspec[1] is not None:
-             prefix = self.argspec[1]
-             for i in range(len(inputs)):
-                 key = prefix + str(i)
-                 self._construct_inputs_names = self._construct_inputs_names + (key,)
-                 self._construct_inputs_num = self._construct_inputs_num + 1
-         for i, tensor in enumerate(inputs):
-             key = self._construct_inputs_names[i]
-             # if input is not used, self.parameter_layout_dict may not contain the key
-             if key not in self.parameter_layout_dict:
-                 logger.warning("Layout dict does not contain the key %s.", key)
-                 parallel_inputs_run.append(tensor)
-             else:
-                 layout = self.parameter_layout_dict[key]
-                 new_tensor = _load_tensor_by_layout(tensor, layout)
-                 parallel_inputs_run.append(new_tensor)
-         return tuple(parallel_inputs_run)
+         logger.warning("'load_parameter_slice' function is deprecated.")

      def set_parallel_input_with_inputs(self, *inputs):
          """
          Slice inputs tensors by parallel strategies.

-         Args:
-             inputs (tuple): inputs of construct method.
-         """
-         self._parallel_inputs_run = self._load_inputs(*inputs)
-
-     def _get_construct_inputs_number_and_name(self):
-         """Compute self._construct_inputs_names and self._construct_inputs_num"""
-         from mindspore._extends.parse.parser import get_parse_method_of_class
-
-         fn = get_parse_method_of_class(self)
-         self.argspec = inspect.getfullargspec(fn)
-         self._construct_inputs_num = fn.__code__.co_argcount
-         self._construct_inputs_names = fn.__code__.co_varnames
-
-         if self._construct_inputs_num <= 0:
-             raise ValueError(f"For 'set_auto_parallel', the number of inputs must be greater than 0,"
-                              f"but got {self._construct_inputs_num}.")
-         if self._construct_inputs_names[0] != 'self':
-             raise ValueError(f"First member of fn function must be self, but got {self._construct_inputs_names[0]}")
-         if self._construct_inputs_num - 1 > len(self._construct_inputs_names):
-             raise ValueError(f"Num of inputs must be greater than num of fn function members, num of inputs is \
-                 {self._construct_inputs_names - 1}, num of fn function members is {len(self._construct_inputs_names)}")
-         self._construct_inputs_names = self._construct_inputs_names[1:self._construct_inputs_num]
-         self._construct_inputs_num = self._construct_inputs_num - 1
+         Note:
+             This interface is deprecated.
+         """
+         logger.warning("'set_parallel_input_with_inputs' function is deprecated.")

      def set_inputs(self, *inputs):
          """
@@ -917,8 +875,8 @@ class Cell(Cell_):
          Args:
              inputs (tuple): Inputs of the Cell object.

-         Note:
-             This is an experimental interface that is subject to change or deletion.
+         .. warning::
+             This is an experimental API that is subject to change or deletion.

          Examples:
              >>> import numpy as np
@@ -949,7 +907,7 @@ class Cell(Cell_):
          if self._dynamic_shape_inputs:
              ds.config.set_dynamic_shape(True)
              if context._get_mode() == context.PYNATIVE_MODE:
-                 _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
+                 _pynative_executor.set_dynamic_input(self)

      def get_inputs(self):
          """
  """
@@ -958,36 +916,31 @@ class Cell(Cell_):
958
916
  Returns:
959
917
  inputs (tuple), Inputs of the Cell object.
960
918
 
961
- Note:
962
- This is an experimental interface that is subject to change or deletion.
919
+ .. warning::
920
+ This is an experimental API that is subject to change or deletion.
963
921
  """
964
922
 
965
923
  return self._dynamic_shape_inputs
966
924
 
967
- def compile(self, *inputs):
925
+ def compile(self, *args, **kwargs):
968
926
  """
969
927
  Compile Cell as a computation graph, the input must be consistent with the input defined in construct.
970
928
 
971
929
  Args:
972
- inputs (tuple): Inputs of the Cell object.
930
+ args (tuple): Args of the Cell object.
931
+ kwargs (dict): Kwargs of the Cell object.
973
932
  """
974
- if self._dynamic_shape_inputs is None or self._dynamic_shape_inputs[0] is None:
975
- _cell_graph_executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode,
976
- jit_config_dict=self._jit_config_dict)
933
+ if self._dynamic_shape_inputs is None:
934
+ _cell_graph_executor.compile(self, phase=self.phase,
935
+ jit_config_dict=self._jit_config_dict, *args, **kwargs)
977
936
  else:
978
- self._check_compile_dynamic_shape(*inputs)
979
- if self.saved_dynamic_shape:
980
- for i in range(len(self.saved_dynamic_shape)):
981
- if self.saved_dynamic_shape[i].shape != self._dynamic_shape_inputs[i].shape:
982
- return
983
-
937
+ self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
984
938
  self.saved_dynamic_shape = self._dynamic_shape_inputs
985
939
  _cell_graph_executor.compile(self, *self._dynamic_shape_inputs, phase=self.phase,
986
- auto_parallel_mode=self._auto_parallel_mode,
987
- jit_config_dict=self._jit_config_dict)
940
+ jit_config_dict=self._jit_config_dict, **kwargs)
988
941
  logger.debug("Compiled Graph with dynamic shape")
989
942
 
990
- def compile_and_run(self, *inputs):
943
+ def compile_and_run(self, *args, **kwargs):
991
944
  """
992
945
  Compile and run Cell, the input must be consistent with the input defined in construct.
993
946
 
@@ -995,25 +948,25 @@ class Cell(Cell_):
          Note:
              It is not recommended to call this directly.

          Args:
-             inputs (tuple): Inputs of the Cell object.
+             args (tuple): Args of the Cell object.
+             kwargs (dict): Kwargs of the Cell object.

          Returns:
              Object, the result of executing.
          """
-         self._auto_parallel_compile_and_run = True
-         self.compile(*inputs)
+         self.compile(*args, **kwargs)

-         new_inputs = _get_args_for_run(self, inputs)
-         return _cell_graph_executor(self, *new_inputs, phase=self.phase)
+         new_args = _get_args_for_run(self, args, kwargs)
+         return _cell_graph_executor(self, *new_args, phase=self.phase)

      def auto_parallel_compile_and_run(self):
          """
          Whether or not to execute compile and run in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.

-         Returns:
-             bool, `_auto_parallel_compile_and_run` value.
+         Note:
+             This interface is deprecated.
          """
-         return self._auto_parallel_compile_and_run
+         logger.warning("'auto_parallel_compile_and_run' function is deprecated.")

      def exec_checkpoint_graph(self):
          """Executes saving checkpoint graph operation."""
@@ -1063,6 +1016,8 @@ class Cell(Cell_):
          Returns:
              Parameter, the input parameter with type automatically cast.
          """
+         msg = "'cast_param' is deprecated from version 2.0 and will be removed in a future version."
+         logger.warning(msg)
          mixed_type = self.get_mixed_precision_type()
          if mixed_type != MixedPrecisionType.NOTSET:
              if mixed_type == MixedPrecisionType.FP32:
@@ -1084,8 +1039,12 @@ class Cell(Cell_):

          Raises:
              KeyError: Child Cell's name is incorrect or duplicated with another child's name.
+             TypeError: If the type of `child_name` is not str.
              TypeError: Child Cell's type is incorrect.
          """
+         if not isinstance(child_name, str):
+             raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
+                             f"but got {type(child_name)}.")
          if not child_name or '.' in child_name:
              raise KeyError("For 'insert_child_to_cell', the parameter 'child_name' can not be None and "
                             "can not contain '.'")
@@ -1097,7 +1056,7 @@ class Cell(Cell_):
                              f"but got type {type(child_cell)}.")
          self._cells[child_name] = child_cell

-     def construct(self, *inputs, **kwargs):
+     def construct(self, *args, **kwargs):
          """
          Defines the computation to be performed. This method must be overridden by all subclasses.

@@ -1105,7 +1064,7 @@ class Cell(Cell_):
          It is currently not supported that inputs contain both tuple and non-tuple types at the same time.

          Args:
-             inputs (tuple): Tuple of variable parameters.
+             args (tuple): Tuple of variable parameters.
              kwargs (dict): Dictionary of variable keyword parameters.

          Returns:
@@ -1158,15 +1117,7 @@ class Cell(Cell_):
          def _updata(param):
              if param in replace:
                  return replace.get(param)
-             layout = None
-             set_sliced = False
-             if auto_parallel_mode:
-                 set_sliced = True
-                 if param.name not in self.parameter_layout_dict:
-                     logger.debug("Layout dict does not contain the key %s.", param.name)
-                 else:
-                     layout = self.parameter_layout_dict[param.name]
-             new_p = param.init_data(layout, set_sliced=set_sliced)
+             new_p = param.init_data(None, set_sliced=False)
              replace[param] = new_p
              return new_p

@@ -1265,6 +1216,7 @@ class Cell(Cell_):
              param.is_init = False
              param.name = prefix + name

+     @jit_forbidden_register
      def trainable_params(self, recurse=True):
          """
          Returns all trainable parameters.
@@ -1279,6 +1231,7 @@ class Cell(Cell_):
          """
          return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))

+     @jit_forbidden_register
      def untrainable_params(self, recurse=True):
          """
          Returns all untrainable parameters.
@@ -1293,6 +1246,7 @@ class Cell(Cell_):
          """
          return list(filter(lambda x: not x.requires_grad, self.get_parameters(expand=recurse)))

+     @jit_forbidden_register
      def get_parameters(self, expand=True):
          """
          Returns an iterator over cell parameters.
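The `@jit_forbidden_register` marks added above flag these accessors as not callable inside jit-compiled code; they are meant to be used on the Python side, for example when wiring an optimizer. A hedged sketch:

```python
import mindspore.nn as nn

net = nn.Dense(3, 2)
params = net.trainable_params()  # call outside construct/@jit code
opt = nn.Momentum(params, learning_rate=0.1, momentum=0.9)
print([p.name for p in params])
```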
@@ -1484,6 +1438,38 @@ class Cell(Cell_):
          if "fp32" in flags and flags.get("fp32", False):
              self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)

+     def apply(self, fn):
+         """
+         Applies fn recursively to every subcell (as returned by .cells()) as well as self.
+         Typical use includes initializing the parameters of a model.
+
+         Args:
+             fn (function): Function to be applied to each subcell.
+
+         Returns:
+             Cell, self.
+
+         Examples:
+             >>> import mindspore.nn as nn
+             >>> from mindspore.common.initializer import initializer, One
+             >>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
+             >>> def func(cell):
+             ...     if isinstance(cell, nn.Dense):
+             ...         cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
+             >>> net.apply(func)
+             SequentialCell<
+               (0): Dense<input_channels=2, output_channels=2, has_bias=True>
+               (1): Dense<input_channels=2, output_channels=2, has_bias=True>
+               >
+             >>> print(net[0].weight.asnumpy())
+             [[1. 1.]
+              [1. 1.]]
+         """
+         for cell in self.cells():
+             cell.apply(fn)
+         fn(self)
+         return self
+
      def add_flags(self, **flags):
          """
          Add customized attributes for cell.
@@ -1538,7 +1524,7 @@ class Cell(Cell_):
          Add a cast on all inputs of the cell and its child cells so they run with a certain float type.

          If `dst_type` is `mindspore.dtype.float16`, all the inputs of Cell, including input, Parameter and Tensor, will
-         be cast to float16. Please refer to the usage in source code of :func:`mindspore.build_train_network`.
+         be cast to float16. Please refer to the usage in source code of :func:`mindspore.amp.build_train_network`.

          Note:
              Multiple calls will overwrite.
@@ -1554,7 +1540,7 @@ class Cell(Cell_):
              ValueError: If dst_type is not mstype.float32 or mstype.float16.

          Supported Platforms:
-             ``Ascend`` ``GPU`` ``CPU``
+             ``Ascend`` ``GPU`` ``CPU``

          Examples:
              >>> import mindspore.nn as nn
@@ -1570,8 +1556,10 @@ class Cell(Cell_):
                              "but got {}.".format(dst_type))
          if dst_type == mstype.float16:
              self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
+             self.to_float_fp16 = True
          else:
              self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
+             self.to_float_fp16 = False
          flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32}
          self._add_init_args(**flags)
          return self
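The new `to_float_fp16` flag simply mirrors whichever dtype `to_float` was last called with. A minimal sketch of the cast behavior (shapes are illustrative):

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn

net = nn.Dense(3, 2).to_float(ms.float16)  # sets net.to_float_fp16 = True
x = ms.Tensor(np.ones((4, 3), np.float32))
out = net(x)       # inputs are cast to float16 inside the cell
print(out.dtype)   # Float16
```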
@@ -1582,7 +1570,7 @@ class Cell(Cell_):
          accelerate the algorithm in the algorithm library.

          If `boost_type` is not in the algorithm library, please view the algorithms in the algorithm library through
-         `algorithm library <https://gitee.com/mindspore/mindspore/tree/r1.10/mindspore/python/mindspore/boost>`_.
+         `algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.0/mindspore/python/mindspore/boost>`_.

          Note:
              Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
@@ -1627,6 +1615,10 @@ class Cell(Cell_):
          for training and predicting, such as `BatchNorm`, will distinguish between the branches by this attribute. If
          set to True, the training branch will be executed, otherwise another branch.

+         Note:
+             When Model.train() is executed, the framework calls Cell.set_train(True);
+             when Model.eval() is executed, the framework calls Cell.set_train(False).
+
          Args:
              mode (bool): Specifies whether the model is training. Default: True.

@@ -1655,11 +1647,9 @@ class Cell(Cell_):
          Set the cell to auto parallel mode.

          Note:
-             If a cell needs to use the auto parallel or semi auto parallel mode for training, evaluation or prediction,
-             this interface needs to be called by the cell.
+             This interface is deprecated.
          """
-         self._auto_parallel_mode = True
-         self._get_construct_inputs_number_and_name()
+         logger.warning("'set_auto_parallel' function is deprecated.")

      def set_jit_config(self, jit_config):
          """
@@ -1672,6 +1662,11 @@ class Cell(Cell_):
              logger.warning("For Cell, jit config can only be set once, ignore this setting.")
          else:
              self._jit_config_dict = jit_config.jit_config_dict
+             enable_ge = os.getenv("MS_ENABLE_GE") == '1'
+             enable_jit_level_o3 = self._jit_config_dict.get('jit_level') == "O3"
+             if (not enable_ge and enable_jit_level_o3) or (enable_ge and not enable_jit_level_o3):
+                 raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, "
+                                    "jit_level={}".format(os.getenv("MS_ENABLE_GE"),
+                                                          self.jit_config_dict.get('jit_level')))

      def flatten_weights(self, fusion_size=0):
          """
@@ -1695,7 +1690,7 @@ class Cell(Cell_):
          Register a forward pre-hook function for the Cell object.

          Note:
-             - The `register_forward_pre_hook(hook_fn)` does not work in graph mode or ms_function.
+             - The `register_forward_pre_hook(hook_fn)` does not work in graph mode or in functions decorated with 'jit'.
              - 'hook_fn' must be defined as the following code.
                `cell_id` is the information of the registered Cell object, including name and ID. `inputs` is the forward
                input objects passed to the Cell. The 'hook_fn' can modify the forward input objects by returning new
@@ -1758,7 +1753,7 @@ class Cell(Cell_):
              raise TypeError(f"When using 'register_forward_pre_hook(hook_fn)', the type of 'hook_fn' must be a python "
                              f"function, but got {type(hook_fn)}.")
          if hook_fn.__code__.co_name == "staging_specialize":
-             raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@ms_function' is not supported.")
+             raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")

          self._enable_forward_pre_hook = True
          _pynative_executor.set_hook_changed(self)
@@ -1797,7 +1792,7 @@ class Cell(Cell_):
          Set the Cell forward hook function.

          Note:
-             - The `register_forward_hook(hook_fn)` does not work in graph mode or ms_function.
+             - The `register_forward_hook(hook_fn)` does not work in graph mode or in functions decorated with 'jit'.
              - 'hook_fn' must be defined as the following code.
                `cell_id` is the information of the registered Cell object, including name and ID. `inputs` is the forward
                input objects passed to the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can
@@ -1862,7 +1857,7 @@ class Cell(Cell_):
              raise TypeError(f"When using 'register_forward_hook(hook_fn)', the type of 'hook_fn' must be a python "
                              f"function, but got {type(hook_fn)}.")
          if hook_fn.__code__.co_name == "staging_specialize":
-             raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@ms_function' is not supported.")
+             raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")

          self._enable_forward_hook = True
          _pynative_executor.set_hook_changed(self)
@@ -1899,7 +1894,7 @@ class Cell(Cell_):
          Register the backward hook function.

          Note:
-             - The `register_backward_hook(hook_fn)` does not work in graph mode or ms_function.
+             - The `register_backward_hook(hook_fn)` does not work in graph mode or in functions decorated with 'jit'.
              - The 'hook_fn' must be defined as the following code.
                `cell_id` is the information of the registered Cell object, including name and ID. `grad_input` is the
                gradient passed to the Cell. `grad_output` is the gradient computed and passed to the next Cell or
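The notes above describe the PyNative-only hook API. A minimal sketch of a forward pre-hook using the documented `(cell_id, inputs)` signature (the network is illustrative):

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn

ms.set_context(mode=ms.PYNATIVE_MODE)  # hooks take effect only in PyNative mode

def forward_pre_hook_fn(cell_id, inputs):
    print("about to run:", cell_id)
    return inputs  # returning new objects would replace the forward inputs

net = nn.Dense(3, 2)
handle = net.register_forward_pre_hook(forward_pre_hook_fn)
net(ms.Tensor(np.ones((1, 3), np.float32)))
handle.remove()  # deregister the hook via the returned HookHandle
```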
@@ -2002,6 +1997,7 @@ class Cell(Cell_):

          Note:
              It only works when a running task is in the parameter server mode.
+             It is only supported in graph mode.

          Args:
              recurse (bool): Whether to set the trainable parameters of subcells. Default: True.
@@ -2083,9 +2079,6 @@ class Cell(Cell_):
          """
          Set the cell recomputed.
          """
-         if context._get_mode() == context.PYNATIVE_MODE:
-             raise TypeError("Recompute is not supported in pynative mode currently, you can use "
-                             "'context.set_context(mode=context.GRAPH_MODE)' to set graph mode.")
          Validator.check_bool(mode)
          Validator.check_bool(output_recompute)
          if not self._has_config_recompute:
@@ -2184,35 +2177,86 @@ class Cell(Cell_):
              params.append(param)
          return params

-     def _check_compile_dynamic_shape(self, *inputs):
+     def place(self, role, rank_id):
          """
-         Check if graph has been compiled with dynamic shape.
+         Set the label for all operators in this cell.
+         This label tells the MindSpore compiler on which process this cell should be launched.
+         Each process's label consists of the input `role` and `rank_id`, so by assigning different labels
+         to different cells, which are then launched on the corresponding processes, users can launch a
+         distributed training or predicting job.
+
+         Note:
+             - This method is effective only after
+               `mindspore.communication.init()` is called for dynamic cluster building.

          Args:
-             inputs (tuple): Inputs of the Cell object.
+             role (str): The role of the process on which this cell will be launched.
+                 Only 'MS_WORKER' is supported for now.
+             rank_id (int): The rank id of the process on which this cell will be launched.
+                 The rank is unique among processes with the same role.
+
+         Examples:
+             >>> from mindspore import context
+             >>> import mindspore.nn as nn
+             >>> context.set_context(mode=context.GRAPH_MODE)
+             >>> fc = nn.Dense(2, 3)
+             >>> fc.place('MS_WORKER', 0)
+         """
+         all_ops = self._get_prims_recursively()
+         for op in all_ops:
+             op.place(role, rank_id)
+
+     def _check_dynamic_tensor(self, set_input, net_input, index):
          """
-         set_inputs_len = len(self._dynamic_shape_inputs)
-         inputs_len = len(inputs)
-         if set_inputs_len != inputs_len:
-             raise ValueError("The number of 'set_input' Tensor must be equal to network's inputs."
-                              f"but got 'set_inputs': {set_inputs_len} and network's input: {inputs_len}.")
-         for index, (net_input, set_input) in enumerate(zip(inputs, self._dynamic_shape_inputs)):
+         Check whether a tensor is correctly set for dynamic shape.
+
+         Args:
+             set_input (Tensor): Tensor set for dynamic shape.
+             net_input (Tensor): Input tensor of the Cell object.
+             index (int): Tensor index in the set inputs.
+         """
+         if not isinstance(net_input, Tensor):
+             raise TypeError(
+                 f"The {index + 1}th input type of 'set_inputs' must be Tensor, but got {type(net_input)}.")
+         if set_input.dtype != net_input.dtype:
+             raise ValueError(
+                 f"The {index + 1}th input type of 'set_inputs' must be the same as the network's input, "
+                 f"but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
+         if net_input.dim() != 0 and set_input.dim() != net_input.dim():
+             raise ValueError(
+                 f"The {index + 1}th input dims of 'set_inputs' must be the same as the network's input, "
+                 f"but got 'set_inputs': {set_input.dim()} and network's input: {net_input.dim()}.")
+         if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
+             raise ValueError(
+                 f"The {index + 1}th input shape of 'set_inputs' must be the same as the network's input, "
+                 f"but got 'set_inputs': {set_input.shape} and network's input: {net_input.shape}.")
+
+     def _check_compile_dynamic_shape(self, set_inputs, net_inputs):
+         """
+         Check if the graph has been compiled with dynamic shape.
+
+         Args:
+             set_inputs (tuple): Inputs set for dynamic shape.
+             net_inputs (tuple): Inputs of the Cell object.
+         """
+         set_inputs_len = len(set_inputs)
+         net_inputs_len = len(net_inputs)
+         if set_inputs_len != net_inputs_len:
+             raise ValueError("The length of 'set_inputs' must be equal to the network's inputs, "
+                              f"but got 'set_inputs': {set_inputs_len} and network's input: {net_inputs_len}.")
+         for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
              if isinstance(set_input, Tensor):
-                 if not isinstance(net_input, Tensor):
+                 self._check_dynamic_tensor(set_input, net_input, index)
+             elif isinstance(set_input, (tuple, list)):
+                 if not isinstance(net_input, (tuple, list)):
                      raise TypeError(
-                         f"The {index + 1}th input type of 'set_inputs' must be Tensor, but got {type(net_input)}.")
-                 if set_input.dtype is not net_input.dtype:
-                     raise ValueError(
-                         f"The {index + 1}th input type of 'set_inputs' must be the same as network's input, "
-                         f"but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
-                 if net_input.dim() != 0 and set_input.dim() != net_input.dim():
-                     raise ValueError(
-                         f"The {index + 1}th input dims of 'set_inputs' must be the same as network's input, "
-                         f"but got 'set_inputs': {set_input.dim()} and network's input: {net_input.dim()}.")
-                 if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
+                         f"The {index + 1}th input type of 'set_inputs' must be tuple or list, "
+                         f"but got {type(net_input)}.")
+                 self._check_compile_dynamic_shape(set_input, net_input)
+             else:
+                 if net_input != set_input:
                      raise ValueError(
-                         f"The {index + 1}th input shape of 'set_inputs' must be the same as network's input, "
-                         f"but got 'set_inputs': {set_input.shape} and network's input: {net_input.shape}.")
+                         f"The {index + 1}th input of 'set_inputs' must be the same as the network's input, "
+                         f"but got 'set_inputs': {set_input} and network's input: {net_input}.")


  class GraphCell(Cell):
@@ -2228,6 +2272,11 @@ class GraphCell(Cell):
              The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
              If the parameter exists in the graph according to the name, update its value.
              If the parameter does not exist, ignore it. Default: None.
+         obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation. "Dynamic obfuscation" is
+             used for model protection; refer to :func:`mindspore.obfuscate_model`. If the input `graph` is
+             a func_graph loaded from a MindIR file obfuscated with `obf_random_seed`, then `obf_random_seed` should
+             be provided. `obf_random_seed` should be in (0, 9223372036854775807]. Default: None.
+
      Raises:
          TypeError: If the `graph` is not a FuncGraph.
          TypeError: If the `params_init` is not a dict.
@@ -2242,7 +2291,8 @@ class GraphCell(Cell):
          >>> import mindspore as ms
          >>> import mindspore.nn as nn
          >>> from mindspore import Tensor
-         >>>
+         >>> from mindspore import context
+         >>> context.set_context(mode=context.GRAPH_MODE)
          >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
          >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
          >>> ms.export(net, input, file_name="net", file_format="MINDIR")
@@ -2254,13 +2304,23 @@ class GraphCell(Cell):
          [6. 9. 6.]
          [4. 6. 4.]]]]
      """
-     def __init__(self, graph, params_init=None):
+
+     def __init__(self, graph, params_init=None, obf_random_seed=None):
          super(GraphCell, self).__init__(auto_prefix=True)
          if not isinstance(graph, FuncGraph):
              raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, "
                              f"but got type {type(graph)}.")
          self.graph = graph
-
+         self.obf_random_seed = obf_random_seed
+         if obf_random_seed is not None:
+             if not isinstance(obf_random_seed, int):
+                 raise TypeError("'obf_random_seed' must be int, but got {}.".format(type(obf_random_seed)))
+             int_64_max = 9223372036854775807
+             if obf_random_seed <= 0 or obf_random_seed > int_64_max:
+                 raise ValueError(
+                     "'obf_random_seed' must be larger than 0 and less than or equal to int64 max ({}), "
+                     "but got {}.".format(int_64_max, obf_random_seed))
+             self._branch_control_input = _generate_branch_control_input(self.obf_random_seed)
          params_init = {} if params_init is None else params_init
          if not isinstance(params_init, dict):
              raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")
@@ -2277,10 +2337,13 @@ class GraphCell(Cell):
      def construct(self, *inputs):
          return self.graph(*inputs)

-     def __call__(self, *inputs):
+     def __call__(self, *args, **kwargs):
          self.phase = "graph_load_from_mindir"
          self._add_attr("graph_load_from_mindir", self.graph)
-         return self.compile_and_run(*inputs)
+         if not self.obf_random_seed:
+             return self.compile_and_run(*args, **kwargs)
+         append_input = Tensor((numpy.ones((1, 1)) * self._branch_control_input).astype(numpy.int32))
+         return self.compile_and_run(*args, append_input, **kwargs)


  def _check_param_list_tuple(value):
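A minimal sketch of round-tripping a network through MindIR with the extended `GraphCell` signature (the file name is illustrative; `obf_random_seed` is only needed for MindIR files exported with obfuscation):

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn

ms.set_context(mode=ms.GRAPH_MODE)
net = nn.Dense(3, 2)
x = ms.Tensor(np.ones((1, 3), np.float32))
ms.export(net, x, file_name="dense_net", file_format="MINDIR")

graph = ms.load("dense_net.mindir")
loaded = nn.GraphCell(graph)  # pass obf_random_seed=... when loading an obfuscated MindIR
print(loaded(x).shape)        # (1, 2)
```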