mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package's registry page for more details.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -62,27 +62,6 @@ def _get_cache_path():
62
62
  return cache_path
63
63
 
64
64
 
65
- def _get_cuda_bare_metal_version():
66
- """
67
- Automatically get the cuda version.
68
-
69
- Returns:
70
- tuple(str), the version of cuda of the platform.ss
71
- """
72
- raw_output = subprocess.check_output(["nvcc", "-V"],
73
- universal_newlines=True)
74
- output = raw_output.split()
75
- release_idx = output.index("release") + 1
76
- release = output[release_idx].split(".")
77
- version_major = release[0]
78
- version_idx = release_idx + 1
79
- version = output[version_idx].split(".")
80
- version_middle = version[1] if len(version) > 1 else 0
81
- version_minor = version[2] if len(version) > 2 else 0
82
-
83
- return int(version_major), int(version_middle), int(version_minor)
84
-
85
-
86
65
  def _compile_aot(file):
87
66
  """
88
67
  Automatically compile the source file for custom aot
@@ -99,11 +78,7 @@ def _compile_aot(file):
99
78
  cache_path = os.path.join(cache_path, "rank_" + str(get_rank()), "")
100
79
  os.makedirs(cache_path, exist_ok=True)
101
80
 
102
- search_res = importlib.util.find_spec("mindspore")
103
- if search_res is None:
104
- raise RuntimeError("Cannot find mindspore module!")
105
-
106
- res_path = search_res.origin
81
+ res_path = importlib.util.find_spec("mindspore").origin
107
82
  find_pos = res_path.find("__init__.py")
108
83
  if find_pos == -1:
109
84
  raise RuntimeError(
@@ -111,9 +86,8 @@ def _compile_aot(file):
111
86
  include_file = "-I{}include/api/".format(res_path[:find_pos])
112
87
 
113
88
  file_name = file.split('/')[-1]
114
- file_folder = file[:file.rindex('/')]
115
89
  func_path = cache_path + file_name + ".so"
116
- include_file = "{} -I{}".format(include_file, file_folder)
90
+ include_file = "{} -I{}".format(include_file, file[:file.rindex('/')])
117
91
 
118
92
  if func_path not in Custom.compiled_bin:
119
93
  Custom.compiled_bin.append(func_path)
@@ -127,10 +101,23 @@ def _compile_aot(file):
127
101
  cmd += ["--use_fast_math", "--expt-relaxed-constexpr"]
128
102
  cmd += ["-D_GLIBCXX_USE_CXX11_ABI=0"]
129
103
 
104
+ def _get_cuda_bare_metal_version():
105
+ raw_output = subprocess.check_output(["nvcc", "-V"],
106
+ universal_newlines=True)
107
+ output = raw_output.split()
108
+ release_idx = output.index("release") + 1
109
+ release = output[release_idx].split(".")
110
+ version_idx = release_idx + 1
111
+ version = output[version_idx].split(".")
112
+ version_middle = version[1] if len(version) > 1 else 0
113
+ version_minor = version[2] if len(version) > 2 else 0
114
+
115
+ return int(release[0]), int(version_middle), int(version_minor)
116
+
130
117
  v_major, v_mid, v_minor = _get_cuda_bare_metal_version()
131
118
  if v_major >= 11:
132
119
  cmd += ["-gencode", "arch=compute_80,code=sm_80", "--expt-extended-lambda"]
133
- elif v_major == 10 and not(v_mid >= 1 and v_minor >= 168):
120
+ elif v_major == 10 and not (v_mid >= 1 and v_minor >= 168):
134
121
  logger.warning("The current version of nvcc, V{}.{}.{}, might have unfixed issues with std string, "
135
122
  "which will lead to errors in aot custom op with attrs."
136
123
  "The version higher than V10.1.168 is recommended".format(v_major, v_mid, v_minor))
@@ -159,10 +146,11 @@ class Custom(ops.PrimitiveWithInfer):
159
146
  function if needed. Then these `Custom` objects can be directly used in neural networks.
160
147
  Detailed description and introduction of user-defined operators, including correct writing of parameters,
161
148
  please refer to `Custom Operators Tutorial
162
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/operation/op_custom.html>`_ .
149
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html>`_ .
163
150
 
164
151
  .. warning::
165
- This is an experimental API that is subject to change.
152
+ - This is an experimental API that is subject to change.
153
+ - Currently, the functionality of Custom does not support Ascend 910B.
166
154
 
167
155
  .. note::
168
156
  The supported platforms are determined by the input `func_type`. The supported platforms are as follows:
@@ -175,6 +163,12 @@ class Custom(ops.PrimitiveWithInfer):
175
163
  - "julia": supports ["CPU"].
176
164
  - "aicpu": supports ["Ascend"].
177
165
 
166
+ If run on ge backend, use `CustomRegOp` to generate the registration information of "aicpu" and "tbe" operator,
167
+ use `custom_info_register` to bind the registration information to the `func` of the "tbe" operator,
168
+ then save the registration information of "aicpu" operator and the `func` implementation of "tbe" operator to
169
+ a file or separate files, keep these files in a separate directory, and set the absolute path of this directory
170
+ to environment variable "MS_DEV_CUSTOM_OPP_PATH" before running the network.
171
+
178
172
  Args:
179
173
  func (Union[function, str]):
180
174
 
@@ -265,7 +259,7 @@ class Custom(ops.PrimitiveWithInfer):
265
259
  (ex. Custom(func="./add.jl:Add:add", out_shape=[1], out_dtype=mstype.float32, "julia"))
266
260
 
267
261
  out_shape (Union[function, list, tuple]): The output shape infer function or the value of output shape of
268
- `func`. Default: None.
262
+ `func`. Default: ``None`` .
269
263
 
270
264
  If func has single output, then the value of output shape is a list or tuple of int.
271
265
 
@@ -276,7 +270,7 @@ class Custom(ops.PrimitiveWithInfer):
276
270
  shape mechanic will be enabled.
277
271
 
278
272
  out_dtype (Union[function, :class:`mindspore.dtype`, tuple[:class:`mindspore.dtype`]]): The output data type
279
- infer function or the value of output data type of `func`. Default: None.
273
+ infer function or the value of output data type of `func`. Default: ``None`` .
280
274
 
281
275
  If func has single output, then the value of output shape is a `mindspore.dtype`.
282
276
 
@@ -288,23 +282,23 @@ class Custom(ops.PrimitiveWithInfer):
288
282
 
289
283
  func_type (str): The implementation type of `func`, should be one of
290
284
 
291
- ["hybrid", "akg", "tbe", "aot", "pyfunc", "julia", "aicpu"].
285
+ [ ``"hybrid"`` , ``"akg"`` , ``"tbe"`` , ``"aot"`` , ``"pyfunc"`` , ``"julia"`` , ``"aicpu"`` ].
292
286
 
293
- Each `func_type` only supports specific platforms(targets). Default: "hybrid".
287
+ Each `func_type` only supports specific platforms(targets). Default: ``"hybrid"`` .
294
288
  The supported platforms of `func_type`:
295
289
 
296
- - "hybrid": supports ["Ascend", "GPU", "CPU"].
297
- - "akg": supports ["Ascend", "GPU", "CPU"].
298
- - "tbe": supports ["Ascend"].
299
- - "aot": supports ["GPU", "CPU"].
300
- - "pyfunc": supports ["CPU"].
301
- - "julia": supports ["CPU"].
302
- - "aicpu": supports ["Ascend"].
290
+ - ``"hybrid"``: supports ["Ascend", "GPU", "CPU"].
291
+ - ``"akg"``: supports ["Ascend", "GPU", "CPU"].
292
+ - ``"tbe"``: supports ["Ascend"].
293
+ - ``"aot"``: supports ["GPU", "CPU"].
294
+ - ``"pyfunc"``: supports ["CPU"].
295
+ - ``"julia"``: supports ["CPU"].
296
+ - ``"aicpu"``: supports ["Ascend"].
303
297
 
304
- bprop (function): The back propagation function of `func`. Default: None.
298
+ bprop (function): The back propagation function of `func`. Default: ``None`` .
305
299
  reg_info (Union[str, dict, list, tuple]): Represents the registration information(reg info) of `func` with
306
300
  json format of type str or dict. The reg info specifies supported data types and formats of inputs and
307
- outputs, attributes and target of `func`. Default: None.
301
+ outputs, attributes and target of `func`. Default: ``None`` .
308
302
 
309
303
  If reg info is a list or tuple, then each item should be with json format of type str or dict, which
310
304
  represents the registration information of `func` in a specific target. You need to invoke `CustomRegOp`
@@ -457,6 +451,7 @@ class Custom(ops.PrimitiveWithInfer):
457
451
  tbe_path_checked = [] # Save paths for tbe functions which is safe to be imported as module.
458
452
  tbe_path_failed = [] # Save paths for tbe functions which fail to be imported as module.
459
453
  op_path_in_cache = [] # Save paths for op functions created in the cached.
454
+ custom_aot_warning = True # Flag to enable warnings about custom aot path white list
460
455
 
461
456
  def __init__(self, func, out_shape=None, out_dtype=None, func_type="hybrid", bprop=None, reg_info=None):
462
457
  ops.PrimitiveWithInfer.__init__(self, "Custom")
@@ -473,6 +468,7 @@ class Custom(ops.PrimitiveWithInfer):
473
468
  self._func_compile_attrs = {}
474
469
  self._is_ms_kernel = False
475
470
 
471
+ self._check_platform()
476
472
  self._check_func()
477
473
  self._update_func_info(reg_info)
478
474
  self.add_prim_attr("func_name", self.func_name)
@@ -487,21 +483,24 @@ class Custom(ops.PrimitiveWithInfer):
487
483
  self.add_prim_attr("fn_id", func_id)
488
484
 
489
485
  self.out_shape = out_shape
486
+ if self.out_shape is None and self.func_type == "aot":
487
+ self.add_prim_attr("cpp_infer_shape", True)
490
488
  self.out_dtype = out_dtype
491
489
  self.bprop = bprop
492
- self._update_op_attr()
490
+ self.fake_output = False
491
+ self.single_scalar_output = False
492
+ if not self.out_dtype:
493
+ self.fake_output = True
494
+ elif not self.out_shape:
495
+ self.single_scalar_output = True
496
+ self.add_prim_attr("fake_output", self.fake_output)
497
+ self.add_prim_attr("single_scalar_output", self.single_scalar_output)
498
+
493
499
  # Register info
494
500
  self._register_info(reg_info)
495
501
 
496
502
  if func_type == "akg":
497
- self.add_prim_attr('func_source_str', self.func_source_str)
498
- if "ir_builder" in self.func_source_str:
499
- self.func_type = "ir_builder"
500
- elif "compute" in self.func_source_str:
501
- self.func_type = "tvm_compute"
502
- else:
503
- self.func_type = "hybrid"
504
- self._hybrid_func_analyser()
503
+ self._set_akg_kernel_type()
505
504
 
506
505
  if not self.bprop and self.func_type == "hybrid":
507
506
  self._hybrid_autodiff(func_type)
@@ -510,7 +509,6 @@ class Custom(ops.PrimitiveWithInfer):
510
509
  self._update_attr()
511
510
 
512
511
  def __infer__(self, *args):
513
- """Infer function of the custom op"""
514
512
  if callable(self.out_shape):
515
513
  infer_shape = self.out_shape(*(x["shape"] for x in args))
516
514
  else:
@@ -570,21 +568,17 @@ class Custom(ops.PrimitiveWithInfer):
570
568
  return out
571
569
 
572
570
  def get_bprop(self):
573
- """Get the bprop of the custom op"""
574
571
  return self.bprop
575
572
 
576
- def _update_op_attr(self):
577
- """Update the attrs of the custom op"""
578
- if self.out_shape is None and self.func_type == "aot":
579
- self.add_prim_attr("cpp_infer_shape", True)
580
- self.fake_output = False
581
- self.single_scalar_output = False
582
- if not self.out_dtype:
583
- self.fake_output = True
584
- elif not self.out_shape:
585
- self.single_scalar_output = True
586
- self.add_prim_attr("fake_output", self.fake_output)
587
- self.add_prim_attr("single_scalar_output", self.single_scalar_output)
573
+ def _set_akg_kernel_type(self):
574
+ self.add_prim_attr('func_source_str', self.func_source_str)
575
+ if "ir_builder" in self.func_source_str:
576
+ self.func_type = "ir_builder"
577
+ elif "compute" in self.func_source_str:
578
+ self.func_type = "tvm_compute"
579
+ else:
580
+ self.func_type = "hybrid"
581
+ self._hybrid_func_analyser()
588
582
 
589
583
  def _check_julia_func(self):
590
584
  """Check the validity of julia func"""
@@ -602,6 +596,10 @@ class Custom(ops.PrimitiveWithInfer):
602
596
  raise Exception("{}, function {} is not found in source file {}!"
603
597
  .format(self.log_prefix, func, source_file))
604
598
 
599
+ def _check_platform(self):
600
+ if platform.system() != 'Linux':
601
+ raise Exception("Custom op only supported on Linux platform currently.")
602
+
605
603
  def _check_func(self):
606
604
  """Check the validity of func_type and type of func"""
607
605
  if self.func_type not in self.supported_func_type:
@@ -617,7 +615,19 @@ class Custom(ops.PrimitiveWithInfer):
617
615
  "{}, 'func' should be like 'file_name:func_name', but got {}".format(
618
616
  self.log_prefix, self.func))
619
617
  file_path = os.path.abspath(file_name_list[0])
620
- if not file_path.endswith("so"):
618
+ if os.environ.get('MS_CUSTOM_AOT_WHITE_LIST') is None:
619
+ if Custom.custom_aot_warning:
620
+ logger.warning("{}, no white list is set and it might cause problems. "
621
+ "Set the legal path of the file in MS_CUSTOM_AOT_WHITE_LIST"
622
+ .format(self.log_prefix))
623
+ Custom.custom_aot_warning = False
624
+ else:
625
+ legal_path = os.path.abspath(os.environ.get('MS_CUSTOM_AOT_WHITE_LIST'))
626
+ if legal_path not in file_path:
627
+ raise TypeError(
628
+ "{}, the legal path for the file is {}, but the file is {}".format(
629
+ self.log_prefix, legal_path, file_path))
630
+ if file_path.endswith(("cu", "cpp", "cc")):
621
631
  file_path = _compile_aot(file_path)
622
632
  self.func = file_path + ":" + file_name_list[1]
623
633
 
@@ -639,7 +649,7 @@ class Custom(ops.PrimitiveWithInfer):
639
649
  "The kernel will be executed as a native python function, which might lead to "
640
650
  "low efficiency. To accelerate the kernel, set the 'func_type' to be \"hybrid\""
641
651
  .format(self.log_prefix))
642
- else:
652
+ elif self.func_type == "tbe":
643
653
  if not callable(self.func):
644
654
  raise TypeError("{}, 'func' must be of type function, but got {}"
645
655
  .format(self.log_prefix, type(self.func)))
@@ -661,10 +671,10 @@ class Custom(ops.PrimitiveWithInfer):
661
671
  if file_path not in Custom.tbe_path_failed:
662
672
  # As a single file might include multiply functions
663
673
  # we will not try the file path which already failed in previous trials
664
- mod_spec = importlib.util.spec_from_file_location(
665
- self.func_name, file_path)
666
- custom_mod = importlib.util.module_from_spec(mod_spec)
667
674
  try:
675
+ mod_spec = importlib.util.spec_from_file_location(
676
+ self.func_name, file_path)
677
+ custom_mod = importlib.util.module_from_spec(mod_spec)
668
678
  mod_spec.loader.exec_module(custom_mod)
669
679
  except (ImportError, RecursionError):
670
680
  Custom.tbe_path_failed.append(file_path)
@@ -756,16 +766,21 @@ class Custom(ops.PrimitiveWithInfer):
756
766
 
757
767
  def _update_reg_attrs(self, reg_info):
758
768
  """Update op attrs in reg_info."""
769
+ output_name_list = []
759
770
  for _, item in enumerate(reg_info.get("outputs", [])):
760
- output_name_list = []
761
771
  if isinstance(item, dict) and item.get("name"):
762
772
  output_name_list.append(item.get("name"))
773
+ if output_name_list:
763
774
  self.add_prim_attr("output_names", output_name_list)
764
775
 
765
776
  if isinstance(reg_info.get("op_name"), str):
766
777
  self.add_prim_attr("reg_op_name", reg_info.get("op_name"))
767
778
 
768
- if self.func_type == "aot":
779
+ if self.func_type == "aicpu":
780
+ self.uniq_name = reg_info["op_name"]
781
+ self.add_prim_attr("uniq_name", self.uniq_name)
782
+
783
+ if self.func_type in ["aot", "aicpu"]:
769
784
  if reg_info.get("attr") is not None and isinstance(reg_info["attr"], list):
770
785
  for item in reg_info["attr"]:
771
786
  if isinstance(item, dict) and item.get("value") is not None:
@@ -852,12 +867,6 @@ class Custom(ops.PrimitiveWithInfer):
852
867
  else:
853
868
  Custom.registered_func[func_name] = [target]
854
869
 
855
- def _get_op_name(self, reg_info):
856
- if self.func_type == "aicpu":
857
- self.uniq_name = reg_info["op_name"]
858
- self.add_prim_attr("uniq_name", self.uniq_name)
859
- return self.uniq_name
860
-
861
870
  def _reformat_reg_info(self, reg_info, target):
862
871
  """Reformat registration information."""
863
872
  if not isinstance(reg_info, dict):
@@ -865,7 +874,7 @@ class Custom(ops.PrimitiveWithInfer):
865
874
  "'CustomRegOp' to generate the registration information, then pass it to 'reg_info' or "
866
875
  "use 'custom_info_register' to bind it to 'func' if 'func' is a function."
867
876
  .format(self.log_prefix, reg_info, type(reg_info)))
868
- reg_info["op_name"] = self._get_op_name(reg_info)
877
+ reg_info["op_name"] = self.uniq_name
869
878
  reg_info["imply_type"] = self._get_imply_type(reg_info, target)
870
879
  if not isinstance(reg_info.get("fusion_type"), str) or not reg_info["fusion_type"].strip():
871
880
  reg_info["fusion_type"] = "OPAQUE"
@@ -926,26 +935,37 @@ class Custom(ops.PrimitiveWithInfer):
926
935
  """Save input_names and attr_names of current func."""
927
936
  if not isinstance(reg_info, dict):
928
937
  return
929
- tensor_inputs = reg_info.get("inputs", [])
930
- attr = reg_info.get("attr", [])
931
- if not isinstance(tensor_inputs, (list, tuple)):
932
- tensor_inputs = [tensor_inputs]
933
- if not isinstance(attr, (list, tuple)):
934
- attr = [attr]
935
- # input_names include tensor input names and attr input names
936
- input_names = []
937
- # attr_names only includes attr input names
938
+
939
+ def _get_value_list(key):
940
+ value = reg_info.get(key, [])
941
+ if not isinstance(value, (list, tuple)):
942
+ value = [value]
943
+ return value
944
+
945
+ tensor_inputs = _get_value_list("inputs")
946
+ attr = _get_value_list("attr")
947
+ input_names = [] # include tensor input names and attr input names
938
948
  attr_names = []
949
+ pure_input_names = []
939
950
  for item in tensor_inputs:
940
951
  if isinstance(item, dict) and item.get("name") is not None:
941
952
  input_names.append(item["name"])
942
- has_input_name = bool(input_names)
953
+ pure_input_names.append(item["name"])
954
+ # attr is converted from inputs only when graph mode or when inputs name is also in reg info
955
+ attr_to_input_safe = bool(input_names) or context.get_context("mode") == ms.GRAPH_MODE
943
956
  for item in attr:
944
957
  if isinstance(item, dict) and item.get("name") is not None:
945
- if has_input_name or context.get_context("mode") != ms.PYNATIVE_MODE:
958
+ # for custom op with function tbe, we always add attrs to inputs as we don't
959
+ # deal with attr value here and leave them to the backend process to fit the
960
+ # usual process of tbe op compiling in mindspore
961
+ # for the rest cases, namely aot and aicpu, if we find values for attrs, we
962
+ # have already add them as prim attr of the op in the fun _update_reg_attrs
963
+ # add attr name to input name only when the value of attr is None in reg info
964
+ # as we need to get values of attrs from inputs
965
+ if attr_to_input_safe and (self.func_type == "tbe" or item.get("value", None) is None):
946
966
  input_names.append(item["name"])
947
967
  attr_names.append(item["name"])
948
- cur_attr = {"input_names": input_names, "attr_names": attr_names}
968
+ cur_attr = {"input_names": input_names, "attr_names": attr_names, "pure_input_names": pure_input_names}
949
969
  # If func does not have attr, save current attr.
950
970
  # Else, check if current attr is same as previous saved one.
951
971
  prev_attr_names = attr_names
@@ -994,7 +1014,12 @@ class Custom(ops.PrimitiveWithInfer):
994
1014
 
995
1015
  def _update_attr(self):
996
1016
  """Add input_names, attr_names, primitive_target to primitive's attr."""
997
- # add input_names, attr_names
1017
+
1018
+ def _add_prim_attr(key):
1019
+ value = func_attr.get(key)
1020
+ if value:
1021
+ self.add_prim_attr(key, value)
1022
+
998
1023
  func_attr = {}
999
1024
  if callable(self.func):
1000
1025
  inputs_num = len(inspect.signature(self.func).parameters)
@@ -1003,12 +1028,9 @@ class Custom(ops.PrimitiveWithInfer):
1003
1028
  elif isinstance(self.func, str):
1004
1029
  func_attr = Custom.attr_dict.get(self.func)
1005
1030
  if isinstance(func_attr, dict):
1006
- input_names = func_attr.get("input_names")
1007
- attr_names = func_attr.get("attr_names")
1008
- if input_names:
1009
- self.add_prim_attr("input_names", input_names)
1010
- if attr_names:
1011
- self.add_prim_attr("attr_names", attr_names)
1031
+ _add_prim_attr("input_names")
1032
+ _add_prim_attr("attr_names")
1033
+ _add_prim_attr("pure_input_names")
1012
1034
  self._add_prim_target()
1013
1035
  if callable(self.func) and callable(self.out_shape):
1014
1036
  if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == "autodiff":
@@ -1065,7 +1087,7 @@ class Custom(ops.PrimitiveWithInfer):
1065
1087
  arg_dtype = arg["dtype"]
1066
1088
  # if any value is missing from input, disable infer value
1067
1089
  enable_infer_value = False
1068
- if isinstance(arg_dtype, mstype.tensor_type):
1090
+ if isinstance(arg_dtype, mstype.TensorType):
1069
1091
  arg_dtype = arg_dtype.element_type()
1070
1092
  fake_arg = np.zeros(arg["shape"]).astype(
1071
1093
  mstype.dtype_to_nptype(arg_dtype))
@@ -1075,7 +1097,7 @@ class Custom(ops.PrimitiveWithInfer):
1075
1097
 
1076
1098
  if hasattr(fake_output, 'shape'):
1077
1099
  infer_shape = fake_output.shape
1078
- infer_dtype = mstype.tensor_type(mstype.pytype_to_dtype(fake_output.dtype))
1100
+ infer_dtype = mstype.TensorType(mstype.pytype_to_dtype(fake_output.dtype))
1079
1101
  else:
1080
1102
  infer_shape = (1,)
1081
1103
  infer_dtype = mstype.pytype_to_dtype(fake_output.dtype)
@@ -43,7 +43,7 @@ def _check_summary_param(name, value, class_name):
43
43
  raise ValueError(f"For '{class_name}', the name must be valid string, but got '{n_value}'.")
44
44
 
45
45
  v_type = value['dtype']
46
- validator.check_value_type('value', v_type, [type(mstype.tensor)], class_name)
46
+ validator.check_value_type('value', v_type, [type(mstype.tensor_type)], class_name)
47
47
 
48
48
 
49
49
  # Note: The return value of the summary operator is not used,
@@ -58,7 +58,7 @@ class ScalarSummary(Primitive):
58
58
  This operator will put a scalar to a summary file with protocol buffer format. It must be used with SummaryRecord
59
59
  or SummaryCollector, which specify the directory of the summary file. The summary file can
60
60
  be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
61
- mindinsight/docs/en/r2.0/index.html>`_ for details.
61
+ mindinsight/docs/en/r2.2/index.html>`_ for details.
62
62
 
63
63
  Inputs:
64
64
  - **name** (str) - The name of the input variable, it must not be an empty string.
@@ -104,6 +104,7 @@ class ScalarSummary(Primitive):
104
104
  raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
105
105
 
106
106
  self.add_prim_attr("side_effect_io", True)
107
+ self.add_prim_attr("channel_name", "ms_scalar_summary")
107
108
 
108
109
  def __call__(self, *args):
109
110
  _cache_summary_data(self.name, args[0], args[1])
@@ -114,7 +115,7 @@ class ImageSummary(PrimitiveWithInfer):
114
115
  This operator will put an image tensor to a summary file with protocol buffer format. It must be used with
115
116
  SummaryRecord or SummaryCollector, which specify the directory of the summary file. The summary file can
116
117
  be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
117
- mindinsight/docs/en/r2.0/index.html>`_ for details.
118
+ mindinsight/docs/en/r2.2/index.html>`_ for details.
118
119
 
119
120
  Inputs:
120
121
  - **name** (str) - The name of the input variable, it must not be an empty string.
@@ -153,6 +154,7 @@ class ImageSummary(PrimitiveWithInfer):
153
154
  raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
154
155
 
155
156
  self.add_prim_attr("side_effect_io", True)
157
+ self.add_prim_attr("channel_name", "ms_image_summary")
156
158
 
157
159
  def __infer__(self, name, value):
158
160
  _check_summary_param(name, value, self.__class__.__name__)
@@ -175,7 +177,7 @@ class TensorSummary(Primitive):
175
177
  This operator will put a tensor to a summary file with protocol buffer format. It must be used with SummaryRecord
176
178
  or SummaryCollector, which specify the directory of the summary file. The summary file can
177
179
  be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
178
- mindinsight/docs/en/r2.0/index.html>`_ for details.
180
+ mindinsight/docs/en/r2.2/index.html>`_ for details.
179
181
 
180
182
  Inputs:
181
183
  - **name** (str) - The name of the input variable.
@@ -221,6 +223,7 @@ class TensorSummary(Primitive):
221
223
  raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
222
224
 
223
225
  self.add_prim_attr("side_effect_io", True)
226
+ self.add_prim_attr("channel_name", "ms_tensor_summary")
224
227
 
225
228
  def __call__(self, *args):
226
229
  _cache_summary_data(self.name, args[0], args[1])
@@ -231,7 +234,7 @@ class HistogramSummary(PrimitiveWithInfer):
231
234
  This operator will calculate the histogram of a tensor and put it to a summary file with protocol buffer format.
232
235
  It must be used with SummaryRecord or SummaryCollector, which specify the directory of the summary file.
233
236
  The summary file can be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
234
- mindinsight/docs/en/r2.0/index.html>`_ for details.
237
+ mindinsight/docs/en/r2.2/index.html>`_ for details.
235
238
 
236
239
  Inputs:
237
240
  - **name** (str) - The name of the input variable.
@@ -276,6 +279,7 @@ class HistogramSummary(PrimitiveWithInfer):
276
279
  raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
277
280
 
278
281
  self.add_prim_attr("side_effect_io", True)
282
+ self.add_prim_attr("channel_name", "ms_histogram_summary")
279
283
 
280
284
  def __infer__(self, name, value):
281
285
  _check_summary_param(name, value, self.__class__.__name__)
@@ -380,7 +384,7 @@ class HookBackward(PrimitiveWithInfer):
380
384
  hook_fn (Function): Python function. hook function.
381
385
  cell_id (str, optional): Used to identify whether the function registered by the hook is actually registered on
382
386
  the specified cell object. For example, 'nn.Conv2d' is a cell object.
383
- The default value of `cell_id` is empty string(""), in this case, the system will automatically
387
+ Default: ``""``, in this case, the system will automatically
384
388
  register a value of `cell_id`.
385
389
  The value of `cell_id` currently does not support custom values.
386
390
 
@@ -444,7 +448,7 @@ class HookBackward(PrimitiveWithInfer):
444
448
 
445
449
  def infer_dtype(self, *inputs_type):
446
450
  for dtype in inputs_type:
447
- validator.check_subclass("input", dtype, [mstype.tensor], self.name)
451
+ validator.check_subclass("input", dtype, [mstype.tensor_type], self.name)
448
452
  if len(inputs_type) == 1:
449
453
  return inputs_type[0]
450
454
  return inputs_type
@@ -456,10 +460,19 @@ class Print(Primitive):
456
460
 
457
461
  Refer to :func:`mindspore.ops.print_` for more detail.
458
462
 
463
+ Inputs:
464
+ - **input_x** (Union[Tensor, bool, int, float, str]) - The graph node to attach to.
465
+ Supports multiple inputs which are separated by ','.
466
+
467
+ Outputs:
468
+ Tensor, has the same data type and shape as original `input_x`.
469
+
459
470
  Supported Platforms:
460
471
  ``Ascend`` ``GPU`` ``CPU``
461
472
 
462
473
  Examples:
474
+ >>> import numpy as np
475
+ >>> from mindspore import Tensor, nn
463
476
  >>> class PrintDemo(nn.Cell):
464
477
  ... def __init__(self):
465
478
  ... super(PrintDemo, self).__init__()
@@ -503,16 +516,16 @@ class Print(Primitive):
503
516
  class Assert(PrimitiveWithInfer):
504
517
  """
505
518
  Asserts whether the given condition is True.
506
- If input condition is identified to be false, print a list of the tensor in data.
519
+ If input condition is identified to be ``False``, print a list of the tensor in data.
507
520
 
508
521
  Args:
509
522
  summarize (int, optional): The number of entries to be printed in each tensor while the given condition is
510
- identified to be False. Default: 3.
523
+ identified to be ``False`` . Default: ``3`` .
511
524
 
512
525
  Inputs:
513
526
  - **condition** (Union[Tensor[bool], bool]) - The condition to be identified.
514
527
  - **input_data** (Union[tuple[Tensor], list[Tensor]]) - The tensors to be printed out when the condition
515
- is false.
528
+ is ``False``.
516
529
 
517
530
  Raises:
518
531
  TypeError: If `summarize` is not an int.
@@ -560,5 +573,5 @@ class Assert(PrimitiveWithInfer):
560
573
  def infer_dtype(self, condition, inputs):
561
574
  validator.check_scalar_or_tensor_types_same({"condition": condition}, [mstype.bool_], self.name)
562
575
  for dtype in inputs:
563
- validator.check_subclass("input", dtype, [mstype.tensor], self.name)
576
+ validator.check_subclass("input", dtype, [mstype.tensor_type], self.name)
564
577
  return mstype.int32