mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (884)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  201. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  202. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  203. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  204. mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
  205. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  206. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  207. mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
  208. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  209. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  210. mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
  211. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  212. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  213. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  214. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  215. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  216. mindspore/log.py +9 -6
  217. mindspore/mindrecord/filereader.py +33 -4
  218. mindspore/mindrecord/filewriter.py +70 -35
  219. mindspore/mindrecord/mindpage.py +40 -34
  220. mindspore/mindrecord/shardreader.py +1 -1
  221. mindspore/mindrecord/shardsegment.py +1 -1
  222. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  223. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  224. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  225. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  226. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  227. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  228. mindspore/nn/cell.py +463 -169
  229. mindspore/nn/dynamic_lr.py +47 -43
  230. mindspore/nn/layer/activation.py +225 -82
  231. mindspore/nn/layer/basic.py +121 -79
  232. mindspore/nn/layer/channel_shuffle.py +21 -21
  233. mindspore/nn/layer/combined.py +33 -26
  234. mindspore/nn/layer/container.py +277 -22
  235. mindspore/nn/layer/conv.py +441 -304
  236. mindspore/nn/layer/dense.py +19 -13
  237. mindspore/nn/layer/embedding.py +62 -49
  238. mindspore/nn/layer/flash_attention.py +264 -0
  239. mindspore/nn/layer/image.py +50 -39
  240. mindspore/nn/layer/math.py +62 -51
  241. mindspore/nn/layer/normalization.py +219 -167
  242. mindspore/nn/layer/padding.py +58 -70
  243. mindspore/nn/layer/pooling.py +334 -287
  244. mindspore/nn/layer/rnn_cells.py +53 -38
  245. mindspore/nn/layer/rnns.py +59 -56
  246. mindspore/nn/layer/thor_layer.py +52 -44
  247. mindspore/nn/layer/timedistributed.py +6 -4
  248. mindspore/nn/layer/transformer.py +284 -164
  249. mindspore/nn/learning_rate_schedule.py +34 -25
  250. mindspore/nn/loss/__init__.py +3 -2
  251. mindspore/nn/loss/loss.py +554 -311
  252. mindspore/nn/optim/ada_grad.py +12 -9
  253. mindspore/nn/optim/adadelta.py +14 -11
  254. mindspore/nn/optim/adafactor.py +19 -16
  255. mindspore/nn/optim/adam.py +62 -47
  256. mindspore/nn/optim/adamax.py +13 -10
  257. mindspore/nn/optim/adasum.py +12 -8
  258. mindspore/nn/optim/asgd.py +10 -9
  259. mindspore/nn/optim/ftrl.py +20 -17
  260. mindspore/nn/optim/lamb.py +16 -12
  261. mindspore/nn/optim/lars.py +8 -6
  262. mindspore/nn/optim/lazyadam.py +25 -20
  263. mindspore/nn/optim/momentum.py +10 -7
  264. mindspore/nn/optim/optimizer.py +61 -9
  265. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  266. mindspore/nn/optim/rmsprop.py +17 -13
  267. mindspore/nn/optim/rprop.py +30 -17
  268. mindspore/nn/optim/sgd.py +40 -23
  269. mindspore/nn/optim/thor.py +24 -26
  270. mindspore/nn/probability/bijector/bijector.py +11 -11
  271. mindspore/nn/probability/bijector/exp.py +1 -1
  272. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  273. mindspore/nn/probability/bijector/invert.py +1 -1
  274. mindspore/nn/probability/bijector/power_transform.py +29 -29
  275. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  276. mindspore/nn/probability/bijector/softplus.py +5 -5
  277. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  278. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  279. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  280. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  281. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  282. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  283. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  284. mindspore/nn/probability/distribution/beta.py +8 -8
  285. mindspore/nn/probability/distribution/categorical.py +23 -15
  286. mindspore/nn/probability/distribution/cauchy.py +5 -6
  287. mindspore/nn/probability/distribution/distribution.py +3 -3
  288. mindspore/nn/probability/distribution/exponential.py +4 -4
  289. mindspore/nn/probability/distribution/gamma.py +10 -10
  290. mindspore/nn/probability/distribution/geometric.py +8 -8
  291. mindspore/nn/probability/distribution/gumbel.py +8 -9
  292. mindspore/nn/probability/distribution/half_normal.py +5 -5
  293. mindspore/nn/probability/distribution/laplace.py +5 -5
  294. mindspore/nn/probability/distribution/log_normal.py +12 -11
  295. mindspore/nn/probability/distribution/logistic.py +8 -8
  296. mindspore/nn/probability/distribution/normal.py +6 -5
  297. mindspore/nn/probability/distribution/poisson.py +10 -11
  298. mindspore/nn/probability/distribution/student_t.py +8 -9
  299. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  300. mindspore/nn/probability/distribution/uniform.py +11 -11
  301. mindspore/nn/reinforcement/tensor_array.py +2 -2
  302. mindspore/nn/sparse/sparse.py +9 -9
  303. mindspore/nn/wrap/cell_wrapper.py +188 -63
  304. mindspore/nn/wrap/grad_reducer.py +21 -12
  305. mindspore/nn/wrap/loss_scale.py +136 -49
  306. mindspore/numpy/__init__.py +4 -4
  307. mindspore/numpy/array_creations.py +55 -56
  308. mindspore/numpy/array_ops.py +134 -35
  309. mindspore/numpy/logic_ops.py +66 -20
  310. mindspore/numpy/math_ops.py +142 -139
  311. mindspore/numpy/utils_const.py +2 -2
  312. mindspore/offline_debug/convert_async.py +2 -2
  313. mindspore/ops/_grad_experimental/__init__.py +7 -5
  314. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  315. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  316. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  317. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  318. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  319. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  320. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  321. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  322. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  323. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  324. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  325. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  326. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  327. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  328. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  329. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  330. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  331. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  332. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  333. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  334. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  335. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  336. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  337. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  338. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  339. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  340. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  341. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  342. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  343. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  344. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  345. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  346. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  347. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  348. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  349. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  350. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  351. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  352. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  353. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  354. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  355. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  356. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  357. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  358. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  359. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  360. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  361. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  362. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  363. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  364. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  365. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  366. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  367. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  368. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  369. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  370. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  371. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  372. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  373. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  374. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  375. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  376. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  377. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  378. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  379. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  380. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  381. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  382. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  383. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  384. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  385. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  386. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  387. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  388. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  389. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  390. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  391. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  392. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  393. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  394. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  395. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  396. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  397. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  398. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  399. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  400. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  401. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  402. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  403. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  404. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  405. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  406. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  407. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  408. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  409. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  410. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  411. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  412. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  413. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  414. mindspore/ops/_primitive_cache.py +1 -1
  415. mindspore/ops/_tracefunc.py +241 -0
  416. mindspore/ops/_utils/utils.py +10 -2
  417. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  418. mindspore/ops/_vmap/vmap_base.py +5 -4
  419. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  420. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  421. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  422. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  423. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  424. mindspore/ops/arg_dtype_cast.py +54 -0
  425. mindspore/ops/composite/__init__.py +7 -5
  426. mindspore/ops/composite/base.py +78 -34
  427. mindspore/ops/composite/math_ops.py +5 -695
  428. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  429. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  430. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  431. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  432. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  433. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  434. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  436. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  437. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  438. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  440. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  441. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  442. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  443. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  444. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  445. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  446. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  447. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  448. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  449. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  450. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  451. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  452. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  453. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  454. mindspore/ops/deprecated.py +304 -0
  455. mindspore/ops/function/__init__.py +41 -4
  456. mindspore/ops/function/array_func.py +1108 -467
  457. mindspore/ops/function/clip_func.py +94 -27
  458. mindspore/ops/function/debug_func.py +3 -1
  459. mindspore/ops/function/grad/grad_func.py +82 -73
  460. mindspore/ops/function/image_func.py +28 -12
  461. mindspore/ops/function/linalg_func.py +135 -39
  462. mindspore/ops/function/math_func.py +3779 -894
  463. mindspore/ops/function/nn_func.py +1584 -657
  464. mindspore/ops/function/parameter_func.py +13 -3
  465. mindspore/ops/function/random_func.py +247 -153
  466. mindspore/ops/function/sparse_func.py +14 -11
  467. mindspore/ops/function/sparse_unary_func.py +173 -47
  468. mindspore/ops/function/spectral_func.py +8 -4
  469. mindspore/ops/function/vmap_func.py +8 -7
  470. mindspore/ops/functional.py +47 -16
  471. mindspore/ops/op_info_register.py +346 -86
  472. mindspore/ops/operations/__init__.py +38 -22
  473. mindspore/ops/operations/_grad_ops.py +145 -149
  474. mindspore/ops/operations/_inner_ops.py +298 -56
  475. mindspore/ops/operations/_ms_kernel.py +3 -3
  476. mindspore/ops/operations/_quant_ops.py +24 -28
  477. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  478. mindspore/ops/operations/_scalar_ops.py +115 -0
  479. mindspore/ops/operations/_sequence_ops.py +148 -10
  480. mindspore/ops/operations/_tensor_array.py +1 -1
  481. mindspore/ops/operations/_thor_ops.py +2 -2
  482. mindspore/ops/operations/array_ops.py +1239 -561
  483. mindspore/ops/operations/comm_ops.py +166 -90
  484. mindspore/ops/operations/control_ops.py +3 -3
  485. mindspore/ops/operations/custom_ops.py +124 -102
  486. mindspore/ops/operations/debug_ops.py +24 -11
  487. mindspore/ops/operations/image_ops.py +86 -71
  488. mindspore/ops/operations/inner_ops.py +18 -13
  489. mindspore/ops/operations/linalg_ops.py +30 -11
  490. mindspore/ops/operations/math_ops.py +1730 -435
  491. mindspore/ops/operations/nn_ops.py +1953 -943
  492. mindspore/ops/operations/other_ops.py +65 -43
  493. mindspore/ops/operations/random_ops.py +258 -98
  494. mindspore/ops/operations/rl_ops.py +4 -36
  495. mindspore/ops/operations/sparse_ops.py +38 -33
  496. mindspore/ops/operations/spectral_ops.py +8 -4
  497. mindspore/ops/primitive.py +66 -44
  498. mindspore/ops/signature.py +5 -5
  499. mindspore/parallel/_auto_parallel_context.py +80 -19
  500. mindspore/parallel/_cost_model_context.py +42 -0
  501. mindspore/parallel/_offload_context.py +162 -72
  502. mindspore/parallel/_parallel_serialization.py +2 -2
  503. mindspore/parallel/_ps_context.py +16 -4
  504. mindspore/parallel/_recovery_context.py +2 -1
  505. mindspore/parallel/_tensor.py +15 -13
  506. mindspore/parallel/_transformer/layers.py +8 -6
  507. mindspore/parallel/_transformer/loss.py +1 -0
  508. mindspore/parallel/_transformer/moe.py +7 -7
  509. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  510. mindspore/parallel/_transformer/transformer.py +34 -14
  511. mindspore/parallel/_utils.py +36 -14
  512. mindspore/parallel/algo_parameter_config.py +114 -20
  513. mindspore/parallel/checkpoint_transform.py +16 -18
  514. mindspore/parallel/shard.py +16 -13
  515. mindspore/profiler/__init__.py +1 -1
  516. mindspore/profiler/common/struct_type.py +3 -3
  517. mindspore/profiler/common/util.py +3 -2
  518. mindspore/profiler/envprofiling.py +11 -4
  519. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  520. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  521. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  522. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  523. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  524. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  525. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  526. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  527. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  528. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  529. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  530. mindspore/profiler/parser/flops_parser.py +15 -11
  531. mindspore/profiler/parser/framework_parser.py +92 -73
  532. mindspore/profiler/parser/hccl_parser.py +16 -12
  533. mindspore/profiler/parser/integrator.py +22 -11
  534. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  535. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  536. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  537. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  538. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  539. mindspore/profiler/parser/optime_parser.py +1 -1
  540. mindspore/profiler/parser/profiler_info.py +4 -5
  541. mindspore/profiler/parser/step_trace_parser.py +11 -14
  542. mindspore/profiler/profiling.py +678 -377
  543. mindspore/rewrite/api/node.py +211 -54
  544. mindspore/rewrite/api/node_type.py +5 -0
  545. mindspore/rewrite/api/pattern_engine.py +22 -23
  546. mindspore/rewrite/api/scoped_value.py +20 -17
  547. mindspore/rewrite/api/symbol_tree.py +252 -106
  548. mindspore/rewrite/api/tree_node_helper.py +3 -0
  549. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  550. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  551. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  552. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  553. mindspore/rewrite/common/rewrite_elog.py +5 -1
  554. mindspore/rewrite/namer.py +51 -51
  555. mindspore/rewrite/namespace.py +14 -5
  556. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  557. mindspore/rewrite/node/call_function.py +79 -0
  558. mindspore/rewrite/node/cell_container.py +135 -0
  559. mindspore/rewrite/node/control_flow.py +88 -0
  560. mindspore/rewrite/{node.py → node/node.py} +313 -247
  561. mindspore/rewrite/node/node_manager.py +254 -0
  562. mindspore/rewrite/node/node_topological_manager.py +243 -0
  563. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  564. mindspore/rewrite/parsers/assign_parser.py +225 -239
  565. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  566. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  567. mindspore/rewrite/parsers/constant_parser.py +9 -6
  568. mindspore/rewrite/parsers/container_parser.py +9 -7
  569. mindspore/rewrite/parsers/for_parser.py +36 -15
  570. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  571. mindspore/rewrite/parsers/if_parser.py +28 -24
  572. mindspore/rewrite/parsers/module_parser.py +202 -25
  573. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  574. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  575. mindspore/rewrite/parsers/return_parser.py +6 -6
  576. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  577. mindspore/rewrite/sparsify/sparsify.py +4 -1
  578. mindspore/rewrite/sparsify/utils.py +11 -5
  579. mindspore/rewrite/symbol_tree.py +577 -732
  580. mindspore/rewrite/symbol_tree_builder.py +9 -175
  581. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  582. mindspore/run_check/_check_version.py +46 -39
  583. mindspore/run_check/run_check.py +3 -2
  584. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  585. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  586. mindspore/scipy/__init__.py +1 -1
  587. mindspore/scipy/linalg.py +67 -61
  588. mindspore/scipy/ops.py +5 -41
  589. mindspore/scipy/ops_grad.py +3 -2
  590. mindspore/scipy/ops_wrapper.py +5 -5
  591. mindspore/scipy/optimize/line_search.py +8 -8
  592. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  593. mindspore/scipy/optimize/minimize.py +16 -12
  594. mindspore/scipy/utils.py +1 -52
  595. mindspore/scipy/utils_const.py +4 -4
  596. mindspore/train/__init__.py +4 -4
  597. mindspore/train/_utils.py +13 -5
  598. mindspore/train/amp.py +410 -148
  599. mindspore/train/anf_ir_pb2.py +16 -4
  600. mindspore/train/callback/_backup_and_restore.py +8 -11
  601. mindspore/train/callback/_callback.py +80 -3
  602. mindspore/train/callback/_checkpoint.py +82 -51
  603. mindspore/train/callback/_early_stop.py +12 -15
  604. mindspore/train/callback/_history.py +1 -1
  605. mindspore/train/callback/_lambda_callback.py +13 -13
  606. mindspore/train/callback/_landscape.py +21 -17
  607. mindspore/train/callback/_loss_monitor.py +9 -10
  608. mindspore/train/callback/_on_request_exit.py +16 -33
  609. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  610. mindspore/train/callback/_summary_collector.py +44 -30
  611. mindspore/train/callback/_time_monitor.py +62 -12
  612. mindspore/train/data_sink.py +10 -16
  613. mindspore/train/dataset_helper.py +154 -86
  614. mindspore/train/loss_scale_manager.py +14 -9
  615. mindspore/train/metrics/__init__.py +10 -2
  616. mindspore/train/metrics/accuracy.py +1 -1
  617. mindspore/train/metrics/auc.py +1 -1
  618. mindspore/train/metrics/bleu_score.py +2 -2
  619. mindspore/train/metrics/confusion_matrix.py +14 -14
  620. mindspore/train/metrics/cosine_similarity.py +3 -3
  621. mindspore/train/metrics/dice.py +1 -1
  622. mindspore/train/metrics/fbeta.py +1 -1
  623. mindspore/train/metrics/hausdorff_distance.py +8 -6
  624. mindspore/train/metrics/mean_surface_distance.py +5 -4
  625. mindspore/train/metrics/metric.py +49 -17
  626. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  627. mindspore/train/metrics/perplexity.py +1 -1
  628. mindspore/train/metrics/precision.py +2 -2
  629. mindspore/train/metrics/recall.py +2 -3
  630. mindspore/train/metrics/roc.py +7 -7
  631. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  632. mindspore/train/metrics/topk.py +7 -4
  633. mindspore/train/mind_ir_pb2.py +193 -48
  634. mindspore/train/model.py +377 -133
  635. mindspore/train/serialization.py +697 -245
  636. mindspore/train/summary/_summary_adapter.py +5 -2
  637. mindspore/train/summary/_writer_pool.py +4 -3
  638. mindspore/train/summary/summary_record.py +25 -23
  639. mindspore/train/train_thor/convert_utils.py +39 -23
  640. mindspore/train/train_thor/dataset_helper.py +4 -3
  641. mindspore/train/train_thor/model_thor.py +8 -8
  642. mindspore/version.py +1 -1
  643. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  644. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
  645. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  646. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  647. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  648. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  649. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  650. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  651. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  652. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  653. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  654. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  655. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  656. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  657. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  658. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  659. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  660. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  661. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  662. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  663. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  664. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  665. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  666. mindspore/_extends/graph_kernel/expander.py +0 -80
  667. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  668. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  669. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  670. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  671. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  672. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  673. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  674. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  675. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  676. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  677. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  678. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  679. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  680. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  681. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  682. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  683. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  684. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  685. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  686. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  687. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  688. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  689. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  690. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  691. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  692. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  693. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  694. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  695. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  696. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  697. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  698. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  699. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  700. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  701. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  702. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  703. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  704. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  705. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  706. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  707. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  708. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  709. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  710. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  711. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  712. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  713. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  714. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  715. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  716. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  717. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  718. mindspore/dataset/engine/graphdata.py +0 -1586
  719. mindspore/include/api/net.h +0 -142
  720. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  721. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  722. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  723. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  724. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  725. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  726. mindspore/ops/_grad/grad_other_ops.py +0 -89
  727. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  728. mindspore/ops/_grad/grad_sparse.py +0 -323
  729. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  730. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  731. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  732. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  733. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  734. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  735. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  736. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  737. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  738. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  739. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  740. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  741. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  742. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  743. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  744. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  745. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  746. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  747. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  748. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  749. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  751. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  752. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  753. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  754. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  755. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  756. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  758. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  759. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  760. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  761. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  762. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  763. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  764. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  765. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  766. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  767. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  768. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  769. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  770. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  771. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  772. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  774. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  775. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  776. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  777. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  778. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  779. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  780. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  784. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  785. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  786. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  787. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  789. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  790. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  791. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  792. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  793. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  794. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  796. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  797. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  798. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  799. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  800. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  801. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  802. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  803. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  804. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  805. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  806. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  807. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  808. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  809. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  810. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  811. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  812. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  814. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  815. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  816. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  817. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  818. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  819. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  820. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  821. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  822. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  823. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  824. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  825. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  826. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  827. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  828. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  829. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  830. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  831. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  832. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  833. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  834. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  835. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  836. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  837. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  839. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  840. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  841. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  842. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  843. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  844. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  845. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  846. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  847. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  848. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  849. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  850. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  851. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  852. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  853. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  854. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  855. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  856. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  857. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  858. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  859. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  860. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  861. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  862. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  863. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  864. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  865. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  866. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  867. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  868. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  869. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  870. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  871. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  872. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  873. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  874. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  875. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  876. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  877. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  878. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  879. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  880. mindspore/rewrite/node_visitor.py +0 -44
  881. mindspore/rewrite/topological_manager.py +0 -203
  882. mindspore/scipy/sparse/linalg.py +0 -192
  883. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  884. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,67 @@
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """The base class of tiling strategy"""
16
+ from abc import ABCMeta
17
+ from abc import abstractmethod
18
+ from collections import namedtuple
19
+
20
# Immutable record of the six tiling parameters produced by a strategy:
# block sizes (Br/Bc), the sizes of the last blocks (last_Br/last_Bc) and
# the block counts (Tr/Tc).
TilingPara = namedtuple("TilingPara", "Br last_Br Bc last_Bc Tr Tc")


class TilingStrategy(metaclass=ABCMeta):
    """Tiling strategy interface. All implementations should be defined in this module,
    otherwise, the UT will fail.

    Every subclass is registered automatically under the key returned by its
    ``strategy_name()`` and can later be looked up with ``from_strategy_name``.
    """

    # Registry: strategy name -> strategy class, filled in by __init_subclass__.
    _strategies = {}

    def __init__(self, Nq, N, head_dim) -> None:
        """Store the problem sizes; the tiling fields stay ``None`` until ``tiling()`` runs.

        Args:
            Nq: stored as ``self.Nq`` (presumably the query sequence length -- confirm with callers).
            N: stored as ``self.N`` (presumably the key/value sequence length -- confirm with callers).
            head_dim: stored as ``self.d``.
        """
        super().__init__()
        self.Nq = Nq
        self.N = N
        self.Br = None
        self.last_Br = None
        self.Bc = None
        self.last_Bc = None
        self.Tr = None
        self.Tc = None
        self.d = head_dim

    def __init_subclass__(cls, **kwargs):
        """Register every subclass in the shared registry at class-creation time."""
        # Chain to the parent hook so cooperative base classes / mixins still
        # receive their class keyword arguments.
        super().__init_subclass__(**kwargs)
        TilingStrategy._strategies[cls.strategy_name()] = cls

    @classmethod
    @abstractmethod
    def strategy_name(cls):
        """strategy name"""
        raise NotImplementedError

    @classmethod
    def from_strategy_name(cls, stgy_name: str):
        """Return the strategy class registered under ``stgy_name``.

        Raises:
            ValueError: if no strategy with that name has been registered.
        """
        stgy_clz = TilingStrategy._strategies.get(stgy_name)
        if stgy_clz is None:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError(f"Strategy:{stgy_name} not supported")

        return stgy_clz

    @abstractmethod
    def tiling(self) -> TilingPara:
        """tiling"""
        raise NotImplementedError

    def gen_tiling_para(self) -> TilingPara:
        """Pack the current tiling fields of this instance into a ``TilingPara``."""
        return TilingPara(self.Br, self.last_Br, self.Bc, self.last_Bc, self.Tr, self.Tc)
@@ -0,0 +1,62 @@
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """wukong tiling"""
16
+ from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingPara
17
+ from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingStrategy
18
+
19
+
20
class WukongTiling(TilingStrategy):
    """A tiling strategy implementation for wukonghuahua model shape"""

    @classmethod
    def strategy_name(cls):
        """Return the registry key of this strategy."""
        return "wukong"

    def tiling(self) -> TilingPara:
        """Compute the tiling parameters and return them as a ``TilingPara``.

        Shapes this strategy was written for (from the original author notes):
            N = (4096, 1024, 256, 64) or 77
            Nq = (4096, 1024, 256, 64)
            d = dv = (40, 80, 160, 160)
        The space distribution of the backward pass still needs detailed analysis.
        """
        if self.N <= 77:  # [77, 64]
            # cross-attention or self-attention of (64, 64, 160):
            # a single column block spans all of N.
            self.Bc = self.N
            # Author note for d <= 80 ([40, 80]): the memory bottleneck is the cast
            # of the P*V result [Br, dv] in UB. Both head-dim cases use the same
            # row block size, so no branch on self.d is needed.
            self.Br = min(self.Nq, 64)
        else:
            # self-attention
            # Author note: the memory bottleneck is the cast of the Q*K result
            # [Br, Bc] in UB.
            self.Bc = 64
            self.Br = 64

        # Block counts always follow from the block sizes. The former N == 256
        # special case hard-coded Tc = 1 with Bc = 64, which tiled only 64 of the
        # 256 columns; it now yields Tc = 4 like every other self-attention shape.
        self.Tc = self.N // self.Bc
        self.Tr = self.Nq // self.Br

        # No separate tail block: the last blocks use the regular block sizes.
        self.last_Br = self.Br
        self.last_Bc = self.Bc

        return self.gen_tiling_para()
@@ -187,8 +187,8 @@ def core(shape_a_temp, shape_b_temp, shape_output, kernel_name):
187
187
  tik_instance.mmad(resmatmul_local_ub_local_l0c, input_1_local_l1_local_l0a,
188
188
  input_2_local_l1_local_l0b, 128, 128, 256, 0)
189
189
  tik_instance.data_move(resmatmul_local_ub, resmatmul_local_ub_local_l0c, 0, 1, 128, 0, 0, 1)
190
- tik_instance.data_move(resmatmul[cc6 * 256 * 1008 + core_m_idx * 8 * 256 + core_n_idx * 512 * 1008]
191
- , resmatmul_local_ub, 0, 16, 256 // 2, 0, 55 * 16 * 2 // 2)
190
+ tik_instance.data_move(resmatmul[cc6 * 256 * 1008 + core_m_idx * 8 * 256 + core_n_idx * 512 * 1008],
191
+ resmatmul_local_ub, 0, 16, 256 // 2, 0, 55 * 16 * 2 // 2)
192
192
  with tik_instance.else_scope():
193
193
  tik_instance.data_move(input_1_local_l1, input_x1[core_m_idx * (8 * 256 + 128 * 1008)], 0, 7, 112,
194
194
  56 * 16, 0)
@@ -14,12 +14,16 @@
14
14
 
15
15
  """aicpu ops"""
16
16
  from .adaptive_max_pool_3d_grad import _adaptive_max_pool_3d_grad_aicpu
17
+ from .adaptive_max_pool_2d import _adaptive_max_pool_2d_aicpu
17
18
  from .adaptive_max_pool_2d_grad import _adaptive_max_pool_2d_grad_aicpu
18
19
  from .adaptive_avg_pool_3d_grad import _adaptiveavgpool3d_grad_aicpu
19
20
  from .adaptive_avg_pool_3d import _adaptiveavgpool3d_aicpu
21
+ from .adaptive_max_pool_3d import _adaptive_max_pool_3d_aicpu
20
22
  from .tile import _tile_aicpu
21
23
  from .tanh import _tanh_aicpu
22
24
  from .less import _less_aicpu
25
+ from .lstsq import _lstsq_aicpu
26
+ from .left_shift import _left_shift_aicpu
23
27
  from .add import _add_aicpu
24
28
  from .sparse_matrix_transpose import _sparse_matrix_transpose_aicpu
25
29
  from .sparse_matrix_nnz import _sparse_matrix_nnz_aicpu
@@ -46,6 +50,7 @@ from .arg_min import _arg_min_aicpu
46
50
  from .argmin_with_value import _argmin_with_value_aicpu
47
51
  from .avgpool_v1 import _avgpool_v1_aicpu
48
52
  from .avgpool_grad_v1 import _avgpool_grad_v1_aicpu
53
+ from .matrix_logarithm import _matrix_logarithm_aicpu
49
54
  from .matrix_solve import _matrix_solve_aicpu
50
55
  from .betainc import _betainc_aicpu
51
56
  from .bartlett_window import _bartlett_window_aicpu
@@ -91,6 +96,8 @@ from .acos_grad import _acos_grad_aicpu
91
96
  from .expand import _expand_aicpu
92
97
  from .expand_dims import _expand_dims_aicpu
93
98
  from .randperm import _randperm_aicpu
99
+ from .randperm_v2 import _randperm_v2_aicpu
100
+ from .random_poisson import _random_poisson_aicpu
94
101
  from .random_choice_with_mask import _random_choice_with_mask_aicpu
95
102
  from .rsqrt import _rsqrt_aicpu
96
103
  from .sqrt import _sqrt_aicpu
@@ -101,11 +108,14 @@ from .search_sorted import _search_sorted_aicpu
101
108
  from .stack import _stack_aicpu
102
109
  from .unstack import _unstack_aicpu
103
110
  from .unsorted_segment_sum import _unsorted_segment_sum_aicpu
111
+ from .unsorted_segment_prod import _unsorted_segment_prod_aicpu
104
112
  from .addcmul import _addcmul_aicpu
105
113
  from .uniform_candidate_sampler import _uniform_candidate_sampler_aicpu
106
114
  from .log_uniform_candidate_sampler import _log_uniform_candidate_sampler_aicpu
107
115
  from .compute_accidental_hits import _compute_accidental_hits_aicpu
108
116
  from .ctcloss import _ctcloss_aicpu
117
+ from .ctc_loss_v2 import _ctc_loss_v2_aicpu
118
+ from .ctc_loss_v2_grad import _ctc_loss_v2_grad_aicpu
109
119
  from .reverse_sequence import _reverse_sequence_aicpu
110
120
  from .log_matrix_determinant import _log_matrix_determinant_aicpu
111
121
  from .crop_and_resize import _crop_and_resize_aicpu
@@ -119,6 +129,7 @@ from .tanh_grad import _tanh_grad_aicpu
119
129
  from .cast import _cast_aicpu
120
130
  from .mirror_pad import _mirror_pad_aicpu
121
131
  from .mirror_pad_grad import _mirror_pad_grad_aicpu
132
+ from .masked_scatter import _masked_scatter_aicpu
122
133
  from .masked_select import _masked_select_aicpu
123
134
  from .masked_select_grad import _masked_select_grad_aicpu
124
135
  from .mul import _mul_aicpu
@@ -129,8 +140,13 @@ from .sub import _sub_aicpu
129
140
  from .not_equal import _not_equal_aicpu
130
141
  from .poisson import _poisson_aicpu
131
142
  from .update_cache import _update_cache_aicpu
143
+ from .upsample_nearest_3d import _upsample_nearest_3d_aicpu
144
+ from .upsample_nearest_3d_grad import _upsample_nearest_3d_grad_aicpu
145
+ from .upsample_trilinear_3d import _upsample_trilinear_3d_aicpu
146
+ from .upsample_trilinear_3d_grad import _upsample_trilinear_3d_grad_aicpu
132
147
  from .upper_bound import _upper_bound_aicpu
133
148
  from .cache_swap_table import _cache_swap_table_aicpu
149
+ from .uniform import _uniform_aicpu
134
150
  from .uniform_int import _uniform_int_aicpu
135
151
  from .uniform_real import _uniform_real_aicpu
136
152
  from .standard_laplace import _standard_laplace_aicpu
@@ -142,12 +158,13 @@ from .fused_sparse_adam import _fused_sparse_adam_aicpu
142
158
  from .fused_sparse_lazy_adam import _fused_sparse_lazy_adam_aicpu
143
159
  from .fused_sparse_ftrl import _fused_sparse_ftrl_aicpu
144
160
  from .sparse_fill_empty_rows_grad import _sparse_fill_empty_rows_grad_aicpu
161
+ from .sparse_reorder import _sparse_reorder_aicpu
145
162
  from .sparse_reshape import _sparse_reshape_aicpu
146
163
  from .sparse_segment_sqrt_n_grad import _sparse_segment_sqrt_n_grad_aicpu
147
164
  from .sparse_segment_sum import _sparse_segment_sum_aicpu
148
165
  from .sparse_segment_sum_with_num_segments import _sparse_segment_sum_with_num_segments_aicpu
149
166
  from .sparse_softmax_cross_entropy_with_logits_v2 import _sparse_softmax_cross_entropy_with_logits_v2_aicpu
150
- from .sparsesparsemaximum import _sparsesparsemaximum_aicpu
167
+ from .sparse_sparse_maximum import _sparse_sparse_maximum_aicpu
151
168
  from .split import _split_aicpu
152
169
  from .transpose import _transpose_aicpu
153
170
  from .tril_indices import _tril_indices_aicpu
@@ -164,6 +181,8 @@ from .stack_push_pop import _stack_push_aicpu
164
181
  from .stack_push_pop import _stack_pop_aicpu
165
182
  from .asinh import _asinh_aicpu
166
183
  from .stack_push_pop import _stack_destroy_aicpu
184
+ from .matrix_band_part import _matrix_band_part_aicpu
185
+ from .matrix_exp import _matrix_exp_aicpu
167
186
  from .matrix_diag_v3 import _matrix_diag_v3_aicpu
168
187
  from .matrix_diag_part_v3 import _matrix_diag_part_v3_aicpu
169
188
  from .tan import _tan_aicpu
@@ -189,6 +208,7 @@ from .environ_get import _environ_get_aicpu
189
208
  from .environ_destroy_all import _environ_destroy_all_aicpu
190
209
  from .cross import _cross_aicpu
191
210
  from .check_numerics import _check_numerics_aicpu
211
+ from .cummax import _cummax_aicpu
192
212
  from .cumsum import _cumsum_aicpu
193
213
  from .round import _round_aicpu
194
214
  from .stft import _stft_aicpu
@@ -213,6 +233,7 @@ from .scatter_nd_update import _scatter_nd_update_aicpu
213
233
  from .scatter_nd_max import _scatter_nd_max_aicpu
214
234
  from .conj import _conj_aicpu
215
235
  from .scatter_nd_min import _scatter_nd_min_aicpu
236
+ from .scatter_add_with_axis import _scatter_add_with_axis_aicpu
216
237
  from .compare_and_bitpack import _compare_and_bitpack_aicpu
217
238
  from .addcdiv import _addcdiv_aicpu
218
239
  from .unique_consecutive import _unique_consecutive_aicpu
@@ -226,6 +247,7 @@ from .reservoir_replay_buffer import _rrb_sample_op_cpu
226
247
  from .reservoir_replay_buffer import _rrb_destroy_op_cpu
227
248
  from .concat_offset import _concat_offset_aicpu
228
249
  from .range import _range_aicpu
250
+ from .range_v2 import _range_v2_aicpu
229
251
  from .slice_grad import _slice_grad_aicpu
230
252
  from .median import _median_aicpu
231
253
  from .median_grad import _median_grad_aicpu
@@ -233,6 +255,7 @@ from .reduce_sum import _reduce_sum_aicpu
233
255
  from .adaptive_avg_pool_2d import _adaptive_avg_pool_2d_aicpu
234
256
  from .adaptive_avg_pool_2d_grad import _adaptive_avg_pool_2d_grad_aicpu
235
257
  from .fill_v2 import _fill_v2_aicpu
258
+ from .fill_diagonal import _fill_diagonal_aicpu
236
259
  from .data_format_vec_permute import _data_format_vec_permute_aicpu
237
260
  from .multinomial import _multinomial_aicpu
238
261
  from .fft_with_size import _fft_with_size_aicpu
@@ -254,6 +277,7 @@ from .complex import _complex_aicpu
254
277
  from .complex_abs import _complex_abs_aicpu
255
278
  from .concat import _concat_aicpu
256
279
  from .cos import _cos_aicpu
280
+ from .count_nonzero import _count_nonzero_aicpu
257
281
  from .csr_sparse_matrix_to_dense import _csr_sparse_matrix_to_dense_aicpu
258
282
  from .cumprod import _cumprod_aicpu
259
283
  from .exp import _exp_aicpu
@@ -269,6 +293,7 @@ from .one_hot import _one_hot_aicpu
269
293
  from .orgqr import _orgqr_aicpu
270
294
  from .parameterized_truncated_normal import _parameterized_truncated_normal_aicpu
271
295
  from .polar import _polar_aicpu
296
+ from .polygamma import _polygamma_aicpu
272
297
  from .pdist_grad import _pdist_grad_aicpu
273
298
  from .ragged_range import _raggedrange_aicpu
274
299
  from .ragged_tensor_to_sparse import _ragged_tensor_to_sparse_aicpu
@@ -294,12 +319,14 @@ from .cumulative_logsumexp import _cumulative_logsumexp_aicpu
294
319
  from .sparse_segment_sqrt_n import _sparse_segment_sqrt_n_aicpu
295
320
  from .scale_and_translate import _scale_and_translate_aicpu
296
321
  from .quant_dtype_cast import _quant_dtype_cast_aicpu
322
+ from .quantile import _quantile_aicpu
297
323
  from .fse_decode import _fse_decode_aicpu
298
324
  from .dense_to_csr_sparse_matrix import _dense_to_csr_sparse_matrix_aicpu
299
325
  from .dense_to_sparse_set_operation import _dense_to_sparse_set_operation_aicpu
300
326
  from .diag import _diag_aicpu
301
327
  from .diagonal import _diagonal_aicpu
302
328
  from .diag_part import _diag_part_aicpu
329
+ from .digamma import _digamma_aicpu
303
330
  from .bias_add import _bias_add_aicpu
304
331
  from .bias_add_grad import _bias_add_grad_aicpu
305
332
  from .eig import _eig_aicpu
@@ -318,6 +345,8 @@ from .heaviside import _heaviside_aicpu
318
345
  from .hypot import _hypot_aicpu
319
346
  from .identity_n import _identity_n_aicpu
320
347
  from .index_fill import _index_fill_aicpu
348
+ from .index_put import _index_put_aicpu
349
+ from .inplace_index_add import _inplace_index_add_aicpu
321
350
  from .kldivloss import _kldiv_loss_aicpu
322
351
  from .kldivlossgrad import _kldiv_loss_grad_aicpu
323
352
  from .lcm import _lcm_aicpu
@@ -333,7 +362,9 @@ from .pad_v3 import _pad_v3_aicpu
333
362
  from .cholesky import _cholesky_aicpu
334
363
  from .hsv_to_rgb import _hsv_to_rgb_aicpu
335
364
  from .im2col import _im2col_aicpu
365
+ from .bessel_i0 import _bessel_i0_aicpu
336
366
  from .lu_solve import _lu_solve_aicpu
367
+ from .lu import _lu_aicpu
337
368
  from .relu_grad_v3 import _relu_grad_v3_aicpu
338
369
  from .resize_bicubic import _resize_bicubic_aicpu
339
370
  from .extract_glimpse import _extract_glimpse_aicpu
@@ -359,6 +390,7 @@ from .layer_norm_grad_grad import _layernorm_grad_grad_aicpu
359
390
  from .list_diff import _list_diff_aicpu
360
391
  from .log import _log_aicpu
361
392
  from .logspace import _logspace_aicpu
393
+ from .lgamma import _lgamma_aicpu
362
394
  from .matrix_inverse import _matrix_inverse_aicpu
363
395
  from .matrix_power import _matrix_power_aicpu
364
396
  from .max_pool3d_grad_with_argmax import _max_pool3d_grad_with_argmax_aicpu
@@ -375,6 +407,9 @@ from .non_deterministic_ints import _non_deterministic_ints_aicpu
375
407
  from .pow import _pow_aicpu
376
408
  from .real import _real_aicpu
377
409
  from .resize_area import _resize_area_aicpu
410
+ from .segment_mean import _segment_mean_aicpu
411
+ from .segment_min import _segment_min_aicpu
412
+ from .segment_prod import _segment_prod_aicpu
378
413
  from .segment_sum import _segment_sum_aicpu
379
414
  from .set_size import _set_size_aicpu
380
415
  from .slice import _slice_aicpu
@@ -386,6 +421,7 @@ from .sparse_tensor_dense_mat_mul import _sparse_tensor_dense_mat_mul_aicpu
386
421
  from .trace import _trace_aicpu
387
422
  from .tracegrad import _tracegrad_aicpu
388
423
  from .tridiagonal_solve import _tridiagonal_solve_aicpu
424
+ from .tridiagonal_matmul import _tridiagonal_matmul_aicpu
389
425
  from .truncated_normal import _truncated_normal_aicpu
390
426
  from .glu import _glu_aicpu
391
427
  from .deformable_offsets import _deformable_offsets_aicpu
@@ -397,4 +433,8 @@ from .bernoulli import _bernoulli_aicpu
397
433
  from .glu_grad import _glu_grad_aicpu
398
434
  from .sspaddmm import _sspaddmm_aicpu
399
435
  from .sequence_addn import _sequence_addn_aicpu
436
+ from .sequence_concat import _sequence_concat_aicpu
437
+ from .sequence_stack import _sequence_stack_aicpu
400
438
  from .affine_grid import _affine_grid_aicpu
439
+ from .depth_to_space import _depth_to_space_aicpu
440
+ from .eps import _eps_aicpu
@@ -0,0 +1,37 @@
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """AdaptiveMaxPool2D op"""
17
+ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
18
+
19
+ adaptive_max_pool_2d_op_info = AiCPURegOp("AdaptiveMaxPool2D") \
20
+ .fusion_type("OPAQUE") \
21
+ .attr("output_size", "listInt") \
22
+ .input(0, "x", "required") \
23
+ .output(0, "y", "required") \
24
+ .output(1, "argmax", "required") \
25
+ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default) \
26
+ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default) \
27
+ .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I32_Default) \
28
+ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I64_Default) \
29
+ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I64_Default) \
30
+ .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I64_Default) \
31
+ .get_op_info()
32
+
33
+
34
+ @op_info_register(adaptive_max_pool_2d_op_info)
35
+ def _adaptive_max_pool_2d_aicpu():
36
+ """AdaptiveMaxPool2D aicpu register"""
37
+ return
@@ -31,7 +31,6 @@ bias_add_grad_op_info = AiCPURegOp("BiasAddGrad") \
31
31
  .dtype_format(DataType.I64_Default, DataType.I64_Default) \
32
32
  .dtype_format(DataType.F16_Default, DataType.F16_Default) \
33
33
  .dtype_format(DataType.F32_Default, DataType.F32_Default) \
34
- .dtype_format(DataType.F64_Default, DataType.F64_Default) \
35
34
  .dtype_format(DataType.C64_Default, DataType.C64_Default) \
36
35
  .dtype_format(DataType.C128_Default, DataType.C128_Default) \
37
36
  .get_op_info()
@@ -32,6 +32,8 @@ cast_op_info = AiCPURegOp("Cast") \
32
32
  .dtype_format(DataType.U8_Default, DataType.F32_Default) \
33
33
  .dtype_format(DataType.U8_Default, DataType.F64_Default) \
34
34
  .dtype_format(DataType.U8_Default, DataType.BOOL_Default) \
35
+ .dtype_format(DataType.U8_Default, DataType.C64_Default) \
36
+ .dtype_format(DataType.U8_Default, DataType.C128_Default) \
35
37
  .dtype_format(DataType.U16_Default, DataType.U8_Default) \
36
38
  .dtype_format(DataType.U16_Default, DataType.U16_Default) \
37
39
  .dtype_format(DataType.U16_Default, DataType.U32_Default) \
@@ -44,6 +46,8 @@ cast_op_info = AiCPURegOp("Cast") \
44
46
  .dtype_format(DataType.U16_Default, DataType.F32_Default) \
45
47
  .dtype_format(DataType.U16_Default, DataType.F64_Default) \
46
48
  .dtype_format(DataType.U16_Default, DataType.BOOL_Default) \
49
+ .dtype_format(DataType.U16_Default, DataType.C64_Default) \
50
+ .dtype_format(DataType.U16_Default, DataType.C128_Default) \
47
51
  .dtype_format(DataType.U32_Default, DataType.U8_Default) \
48
52
  .dtype_format(DataType.U32_Default, DataType.U16_Default) \
49
53
  .dtype_format(DataType.U32_Default, DataType.U32_Default) \
@@ -56,6 +60,8 @@ cast_op_info = AiCPURegOp("Cast") \
56
60
  .dtype_format(DataType.U32_Default, DataType.F32_Default) \
57
61
  .dtype_format(DataType.U32_Default, DataType.F64_Default) \
58
62
  .dtype_format(DataType.U32_Default, DataType.BOOL_Default) \
63
+ .dtype_format(DataType.U32_Default, DataType.C64_Default) \
64
+ .dtype_format(DataType.U32_Default, DataType.C128_Default) \
59
65
  .dtype_format(DataType.U64_Default, DataType.U8_Default) \
60
66
  .dtype_format(DataType.U64_Default, DataType.U16_Default) \
61
67
  .dtype_format(DataType.U64_Default, DataType.U32_Default) \
@@ -68,6 +74,8 @@ cast_op_info = AiCPURegOp("Cast") \
68
74
  .dtype_format(DataType.U64_Default, DataType.F32_Default) \
69
75
  .dtype_format(DataType.U64_Default, DataType.F64_Default) \
70
76
  .dtype_format(DataType.U64_Default, DataType.BOOL_Default) \
77
+ .dtype_format(DataType.U64_Default, DataType.C64_Default) \
78
+ .dtype_format(DataType.U64_Default, DataType.C128_Default) \
71
79
  .dtype_format(DataType.I8_Default, DataType.U8_Default) \
72
80
  .dtype_format(DataType.I8_Default, DataType.U16_Default) \
73
81
  .dtype_format(DataType.I8_Default, DataType.U32_Default) \
@@ -80,6 +88,8 @@ cast_op_info = AiCPURegOp("Cast") \
80
88
  .dtype_format(DataType.I8_Default, DataType.F32_Default) \
81
89
  .dtype_format(DataType.I8_Default, DataType.F64_Default) \
82
90
  .dtype_format(DataType.I8_Default, DataType.BOOL_Default) \
91
+ .dtype_format(DataType.I8_Default, DataType.C64_Default) \
92
+ .dtype_format(DataType.I8_Default, DataType.C128_Default) \
83
93
  .dtype_format(DataType.I16_Default, DataType.U8_Default) \
84
94
  .dtype_format(DataType.I16_Default, DataType.U16_Default) \
85
95
  .dtype_format(DataType.I16_Default, DataType.U32_Default) \
@@ -92,6 +102,8 @@ cast_op_info = AiCPURegOp("Cast") \
92
102
  .dtype_format(DataType.I16_Default, DataType.F32_Default) \
93
103
  .dtype_format(DataType.I16_Default, DataType.F64_Default) \
94
104
  .dtype_format(DataType.I16_Default, DataType.BOOL_Default) \
105
+ .dtype_format(DataType.I16_Default, DataType.C64_Default) \
106
+ .dtype_format(DataType.I16_Default, DataType.C128_Default) \
95
107
  .dtype_format(DataType.I32_Default, DataType.U8_Default) \
96
108
  .dtype_format(DataType.I32_Default, DataType.U16_Default) \
97
109
  .dtype_format(DataType.I32_Default, DataType.U32_Default) \
@@ -104,6 +116,8 @@ cast_op_info = AiCPURegOp("Cast") \
104
116
  .dtype_format(DataType.I32_Default, DataType.F32_Default) \
105
117
  .dtype_format(DataType.I32_Default, DataType.F64_Default) \
106
118
  .dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
119
+ .dtype_format(DataType.I32_Default, DataType.C64_Default) \
120
+ .dtype_format(DataType.I32_Default, DataType.C128_Default) \
107
121
  .dtype_format(DataType.I32_5HD, DataType.I64_5HD) \
108
122
  .dtype_format(DataType.I64_Default, DataType.U8_Default) \
109
123
  .dtype_format(DataType.I64_Default, DataType.U16_Default) \
@@ -117,6 +131,8 @@ cast_op_info = AiCPURegOp("Cast") \
117
131
  .dtype_format(DataType.I64_Default, DataType.F32_Default) \
118
132
  .dtype_format(DataType.I64_Default, DataType.F64_Default) \
119
133
  .dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
134
+ .dtype_format(DataType.I64_Default, DataType.C64_Default) \
135
+ .dtype_format(DataType.I64_Default, DataType.C128_Default) \
120
136
  .dtype_format(DataType.F16_Default, DataType.U8_Default) \
121
137
  .dtype_format(DataType.F16_Default, DataType.U16_Default) \
122
138
  .dtype_format(DataType.F16_Default, DataType.U32_Default) \
@@ -129,6 +145,8 @@ cast_op_info = AiCPURegOp("Cast") \
129
145
  .dtype_format(DataType.F16_Default, DataType.F32_Default) \
130
146
  .dtype_format(DataType.F16_Default, DataType.F64_Default) \
131
147
  .dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
148
+ .dtype_format(DataType.F16_Default, DataType.C64_Default) \
149
+ .dtype_format(DataType.F16_Default, DataType.C128_Default) \
132
150
  .dtype_format(DataType.F32_Default, DataType.U8_Default) \
133
151
  .dtype_format(DataType.F32_Default, DataType.U16_Default) \
134
152
  .dtype_format(DataType.F32_Default, DataType.U32_Default) \
@@ -141,6 +159,8 @@ cast_op_info = AiCPURegOp("Cast") \
141
159
  .dtype_format(DataType.F32_Default, DataType.F32_Default) \
142
160
  .dtype_format(DataType.F32_Default, DataType.F64_Default) \
143
161
  .dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
162
+ .dtype_format(DataType.F32_Default, DataType.C64_Default) \
163
+ .dtype_format(DataType.F32_Default, DataType.C128_Default) \
144
164
  .dtype_format(DataType.F64_Default, DataType.U8_Default) \
145
165
  .dtype_format(DataType.F64_Default, DataType.U16_Default) \
146
166
  .dtype_format(DataType.F64_Default, DataType.U32_Default) \
@@ -153,6 +173,8 @@ cast_op_info = AiCPURegOp("Cast") \
153
173
  .dtype_format(DataType.F64_Default, DataType.F32_Default) \
154
174
  .dtype_format(DataType.F64_Default, DataType.F64_Default) \
155
175
  .dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
176
+ .dtype_format(DataType.F64_Default, DataType.C64_Default) \
177
+ .dtype_format(DataType.F64_Default, DataType.C128_Default) \
156
178
  .dtype_format(DataType.BOOL_Default, DataType.U8_Default) \
157
179
  .dtype_format(DataType.BOOL_Default, DataType.U16_Default) \
158
180
  .dtype_format(DataType.BOOL_Default, DataType.U32_Default) \
@@ -165,6 +187,36 @@ cast_op_info = AiCPURegOp("Cast") \
165
187
  .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
166
188
  .dtype_format(DataType.BOOL_Default, DataType.F64_Default) \
167
189
  .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
190
+ .dtype_format(DataType.BOOL_Default, DataType.C64_Default) \
191
+ .dtype_format(DataType.BOOL_Default, DataType.C128_Default) \
192
+ .dtype_format(DataType.C64_Default, DataType.U8_Default) \
193
+ .dtype_format(DataType.C64_Default, DataType.U16_Default) \
194
+ .dtype_format(DataType.C64_Default, DataType.U32_Default) \
195
+ .dtype_format(DataType.C64_Default, DataType.U64_Default) \
196
+ .dtype_format(DataType.C64_Default, DataType.I8_Default) \
197
+ .dtype_format(DataType.C64_Default, DataType.I16_Default) \
198
+ .dtype_format(DataType.C64_Default, DataType.I32_Default) \
199
+ .dtype_format(DataType.C64_Default, DataType.I64_Default) \
200
+ .dtype_format(DataType.C64_Default, DataType.F16_Default) \
201
+ .dtype_format(DataType.C64_Default, DataType.F32_Default) \
202
+ .dtype_format(DataType.C64_Default, DataType.F64_Default) \
203
+ .dtype_format(DataType.C64_Default, DataType.BOOL_Default) \
204
+ .dtype_format(DataType.C64_Default, DataType.C64_Default) \
205
+ .dtype_format(DataType.C64_Default, DataType.C128_Default) \
206
+ .dtype_format(DataType.C128_Default, DataType.U8_Default) \
207
+ .dtype_format(DataType.C128_Default, DataType.U16_Default) \
208
+ .dtype_format(DataType.C128_Default, DataType.U32_Default) \
209
+ .dtype_format(DataType.C128_Default, DataType.U64_Default) \
210
+ .dtype_format(DataType.C128_Default, DataType.I8_Default) \
211
+ .dtype_format(DataType.C128_Default, DataType.I16_Default) \
212
+ .dtype_format(DataType.C128_Default, DataType.I32_Default) \
213
+ .dtype_format(DataType.C128_Default, DataType.I64_Default) \
214
+ .dtype_format(DataType.C128_Default, DataType.F16_Default) \
215
+ .dtype_format(DataType.C128_Default, DataType.F32_Default) \
216
+ .dtype_format(DataType.C128_Default, DataType.F64_Default) \
217
+ .dtype_format(DataType.C128_Default, DataType.BOOL_Default) \
218
+ .dtype_format(DataType.C128_Default, DataType.C64_Default) \
219
+ .dtype_format(DataType.C128_Default, DataType.C128_Default) \
168
220
  .get_op_info()
169
221
 
170
222
  @op_info_register(cast_op_info)
@@ -23,6 +23,8 @@ coalesce_op_info = AiCPURegOp("Coalesce") \
23
23
  .output(0, "y_indices", "required") \
24
24
  .output(1, "y_values", "required") \
25
25
  .output(2, "y_shape", "required") \
26
+ .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, DataType.I64_Default,
27
+ DataType.F64_Default, DataType.I64_Default) \
26
28
  .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, DataType.I64_Default,
27
29
  DataType.F32_Default, DataType.I64_Default) \
28
30
  .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.I64_Default, DataType.I64_Default,
@@ -12,7 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
-
16
15
  """Col2Im op"""
17
16
  from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
18
17
 
@@ -27,6 +26,9 @@ col2im_op_info = AiCPURegOp("Col2Im") \
27
26
  .attr("padding", "listInt") \
28
27
  .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
29
28
  .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
29
+ .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
30
+ .dtype_format(DataType.C64_Default, DataType.I32_Default, DataType.C64_Default) \
31
+ .dtype_format(DataType.C128_Default, DataType.I32_Default, DataType.C128_Default) \
30
32
  .get_op_info()
31
33
 
32
34
 
@@ -0,0 +1,43 @@
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """CountNonZero op"""
17
+ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
18
+
19
+ count_nonzero_op_info = AiCPURegOp("CountNonZero") \
20
+ .fusion_type("OPAQUE") \
21
+ .input(0, "x", "required") \
22
+ .output(0, "y", "required") \
23
+ .attr("dims", "listInt")\
24
+ .dtype_format(DataType.I8_Default, DataType.I64_Default) \
25
+ .dtype_format(DataType.I16_Default, DataType.I64_Default) \
26
+ .dtype_format(DataType.I32_Default, DataType.I64_Default) \
27
+ .dtype_format(DataType.I64_Default, DataType.I64_Default) \
28
+ .dtype_format(DataType.U8_Default, DataType.I64_Default) \
29
+ .dtype_format(DataType.U16_Default, DataType.I64_Default) \
30
+ .dtype_format(DataType.U32_Default, DataType.I64_Default) \
31
+ .dtype_format(DataType.U64_Default, DataType.I64_Default) \
32
+ .dtype_format(DataType.F16_Default, DataType.I64_Default) \
33
+ .dtype_format(DataType.F32_Default, DataType.I64_Default) \
34
+ .dtype_format(DataType.F64_Default, DataType.I64_Default) \
35
+ .dtype_format(DataType.C64_Default, DataType.I64_Default) \
36
+ .dtype_format(DataType.C128_Default, DataType.I64_Default) \
37
+ .get_op_info()
38
+
39
+
40
+ @op_info_register(count_nonzero_op_info)
41
+ def _count_nonzero_aicpu():
42
+ """CountNonZero AiCPU register"""
43
+ return
@@ -25,8 +25,14 @@ dropout_genmask_op_info = AiCPURegOp("DropoutGenMask") \
25
25
  .output(0, "y", "required") \
26
26
  .attr("Seed0", "int") \
27
27
  .attr("Seed1", "int") \
28
+ .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.U64_Default, DataType.U64_Default,
29
+ DataType.U8_Default) \
28
30
  .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.U64_Default, DataType.U64_Default,
29
31
  DataType.U8_Default) \
32
+ .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.U64_Default, DataType.U64_Default,
33
+ DataType.U8_Default) \
34
+ .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.U64_Default, DataType.U64_Default,
35
+ DataType.U8_Default) \
30
36
  .get_op_info()
31
37
 
32
38
  @op_info_register(dropout_genmask_op_info)
@@ -0,0 +1,32 @@
1
+ # Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """Eps op"""
17
+ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
18
+
19
+ eps_op_info = AiCPURegOp("Eps") \
20
+ .fusion_type("OPAQUE") \
21
+ .input(0, "x", "required") \
22
+ .output(0, "y", "required") \
23
+ .dtype_format(DataType.F16_Default, DataType.F16_Default) \
24
+ .dtype_format(DataType.F32_Default, DataType.F32_Default) \
25
+ .dtype_format(DataType.F64_Default, DataType.F64_Default) \
26
+ .get_op_info()
27
+
28
+
29
+ @op_info_register(eps_op_info)
30
+ def _eps_aicpu():
31
+ """Eps AiCPU register"""
32
+ return