mindspore-2.1.0-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.11-cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (589)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +61 -95
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/__init__.py +4 -2
  268. mindspore/nn/wrap/cell_wrapper.py +87 -34
  269. mindspore/nn/wrap/grad_reducer.py +8 -5
  270. mindspore/nn/wrap/loss_scale.py +105 -42
  271. mindspore/numpy/array_creations.py +1 -2
  272. mindspore/numpy/array_ops.py +3 -2
  273. mindspore/numpy/utils_const.py +5 -5
  274. mindspore/offline_debug/convert_async.py +2 -2
  275. mindspore/ops/_grad_experimental/__init__.py +0 -5
  276. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  277. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  278. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  279. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  280. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  281. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  282. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  283. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  284. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  285. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  286. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  287. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  288. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  289. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  290. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  291. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  292. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  293. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  294. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  295. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  296. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  297. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  298. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  299. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  300. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  301. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  302. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  303. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  304. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  305. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  306. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  307. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  309. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  310. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  311. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  312. mindspore/ops/_primitive_cache.py +1 -1
  313. mindspore/ops/_tracefunc.py +45 -13
  314. mindspore/ops/_utils/utils.py +6 -1
  315. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  316. mindspore/ops/_vmap/vmap_base.py +3 -3
  317. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  318. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  319. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  320. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  321. mindspore/ops/arg_dtype_cast.py +54 -0
  322. mindspore/ops/composite/base.py +37 -10
  323. mindspore/ops/composite/math_ops.py +5 -4
  324. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  325. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  326. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  327. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  328. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  330. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  331. mindspore/ops/deprecated.py +304 -0
  332. mindspore/ops/function/__init__.py +4 -1
  333. mindspore/ops/function/array_func.py +174 -193
  334. mindspore/ops/function/clip_func.py +81 -13
  335. mindspore/ops/function/debug_func.py +1 -1
  336. mindspore/ops/function/grad/grad_func.py +18 -9
  337. mindspore/ops/function/image_func.py +10 -4
  338. mindspore/ops/function/linalg_func.py +5 -5
  339. mindspore/ops/function/math_func.py +575 -386
  340. mindspore/ops/function/nn_func.py +568 -260
  341. mindspore/ops/function/random_func.py +88 -57
  342. mindspore/ops/function/sparse_func.py +1 -1
  343. mindspore/ops/function/sparse_unary_func.py +14 -12
  344. mindspore/ops/function/vmap_func.py +6 -5
  345. mindspore/ops/functional.py +15 -10
  346. mindspore/ops/op_info_register.py +244 -25
  347. mindspore/ops/operations/__init__.py +31 -19
  348. mindspore/ops/operations/_grad_ops.py +71 -7
  349. mindspore/ops/operations/_inner_ops.py +350 -17
  350. mindspore/ops/operations/_quant_ops.py +4 -8
  351. mindspore/ops/operations/_sequence_ops.py +42 -0
  352. mindspore/ops/operations/array_ops.py +68 -282
  353. mindspore/ops/operations/comm_ops.py +107 -59
  354. mindspore/ops/operations/custom_ops.py +94 -70
  355. mindspore/ops/operations/debug_ops.py +8 -4
  356. mindspore/ops/operations/image_ops.py +18 -12
  357. mindspore/ops/operations/inner_ops.py +26 -3
  358. mindspore/ops/operations/math_ops.py +192 -144
  359. mindspore/ops/operations/nn_ops.py +857 -489
  360. mindspore/ops/operations/other_ops.py +0 -22
  361. mindspore/ops/operations/random_ops.py +53 -111
  362. mindspore/ops/operations/sparse_ops.py +3 -1
  363. mindspore/ops/primitive.py +24 -18
  364. mindspore/parallel/_auto_parallel_context.py +68 -8
  365. mindspore/parallel/_cost_model_context.py +2 -2
  366. mindspore/parallel/_offload_context.py +17 -3
  367. mindspore/parallel/_parallel_serialization.py +12 -5
  368. mindspore/parallel/_ps_context.py +12 -0
  369. mindspore/parallel/_tensor.py +18 -13
  370. mindspore/parallel/_transformer/layers.py +5 -3
  371. mindspore/parallel/_transformer/loss.py +1 -0
  372. mindspore/parallel/_transformer/moe.py +2 -2
  373. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  374. mindspore/parallel/_transformer/transformer.py +23 -3
  375. mindspore/parallel/_utils.py +11 -7
  376. mindspore/parallel/algo_parameter_config.py +85 -5
  377. mindspore/parallel/checkpoint_transform.py +19 -12
  378. mindspore/parallel/shard.py +21 -14
  379. mindspore/profiler/common/struct_type.py +3 -3
  380. mindspore/profiler/common/util.py +4 -2
  381. mindspore/profiler/envprofiling.py +1 -1
  382. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  383. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  384. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  385. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  386. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  387. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  388. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  389. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  390. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  391. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  392. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  393. mindspore/profiler/parser/flops_parser.py +15 -11
  394. mindspore/profiler/parser/framework_parser.py +38 -22
  395. mindspore/profiler/parser/hccl_parser.py +16 -12
  396. mindspore/profiler/parser/integrator.py +22 -11
  397. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  398. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  399. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  400. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  401. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  402. mindspore/profiler/parser/optime_parser.py +1 -1
  403. mindspore/profiler/parser/profiler_info.py +21 -2
  404. mindspore/profiler/parser/step_trace_parser.py +11 -14
  405. mindspore/profiler/profiling.py +179 -89
  406. mindspore/rewrite/api/node.py +102 -19
  407. mindspore/rewrite/api/node_type.py +5 -1
  408. mindspore/rewrite/api/pattern_engine.py +1 -1
  409. mindspore/rewrite/api/scoped_value.py +9 -17
  410. mindspore/rewrite/api/symbol_tree.py +131 -47
  411. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  412. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  413. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  414. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  415. mindspore/rewrite/common/rewrite_elog.py +5 -1
  416. mindspore/rewrite/namer.py +33 -24
  417. mindspore/rewrite/namespace.py +14 -5
  418. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  419. mindspore/rewrite/node/call_function.py +79 -0
  420. mindspore/rewrite/node/cell_container.py +135 -0
  421. mindspore/rewrite/node/control_flow.py +88 -0
  422. mindspore/rewrite/{node.py → node/node.py} +273 -234
  423. mindspore/rewrite/node/node_manager.py +254 -0
  424. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  425. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  426. mindspore/rewrite/parsers/assign_parser.py +216 -221
  427. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  428. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  429. mindspore/rewrite/parsers/constant_parser.py +9 -6
  430. mindspore/rewrite/parsers/container_parser.py +9 -7
  431. mindspore/rewrite/parsers/for_parser.py +42 -21
  432. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  433. mindspore/rewrite/parsers/if_parser.py +28 -24
  434. mindspore/rewrite/parsers/module_parser.py +196 -25
  435. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  436. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  437. mindspore/rewrite/parsers/return_parser.py +6 -6
  438. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  439. mindspore/rewrite/sparsify/utils.py +1 -1
  440. mindspore/rewrite/symbol_tree.py +523 -578
  441. mindspore/rewrite/symbol_tree_builder.py +9 -193
  442. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  443. mindspore/run_check/_check_version.py +6 -4
  444. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  445. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  446. mindspore/scipy/linalg.py +1 -1
  447. mindspore/scipy/ops.py +55 -5
  448. mindspore/scipy/optimize/__init__.py +3 -2
  449. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  450. mindspore/scipy/optimize/minimize.py +7 -3
  451. mindspore/train/_utils.py +7 -3
  452. mindspore/train/amp.py +323 -123
  453. mindspore/train/anf_ir_pb2.py +14 -2
  454. mindspore/train/callback/_backup_and_restore.py +2 -12
  455. mindspore/train/callback/_callback.py +29 -4
  456. mindspore/train/callback/_checkpoint.py +23 -8
  457. mindspore/train/callback/_early_stop.py +2 -2
  458. mindspore/train/callback/_landscape.py +4 -4
  459. mindspore/train/callback/_loss_monitor.py +2 -2
  460. mindspore/train/callback/_on_request_exit.py +2 -2
  461. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  462. mindspore/train/callback/_summary_collector.py +15 -8
  463. mindspore/train/callback/_time_monitor.py +58 -5
  464. mindspore/train/data_sink.py +5 -11
  465. mindspore/train/dataset_helper.py +84 -57
  466. mindspore/train/loss_scale_manager.py +2 -2
  467. mindspore/train/metrics/__init__.py +3 -3
  468. mindspore/train/metrics/cosine_similarity.py +1 -1
  469. mindspore/train/metrics/hausdorff_distance.py +3 -2
  470. mindspore/train/metrics/mean_surface_distance.py +3 -2
  471. mindspore/train/metrics/metric.py +39 -19
  472. mindspore/train/metrics/roc.py +2 -2
  473. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  474. mindspore/train/mind_ir_pb2.py +85 -36
  475. mindspore/train/model.py +187 -47
  476. mindspore/train/serialization.py +487 -161
  477. mindspore/train/summary/_summary_adapter.py +1 -1
  478. mindspore/train/summary/_writer_pool.py +3 -2
  479. mindspore/train/summary/summary_record.py +37 -17
  480. mindspore/train/train_thor/convert_utils.py +3 -3
  481. mindspore/train/train_thor/dataset_helper.py +1 -1
  482. mindspore/version.py +1 -1
  483. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  486. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  487. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  489. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  490. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  491. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  492. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  493. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  494. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  495. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  497. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  498. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  499. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  500. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  501. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  502. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  503. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  504. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  505. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  506. mindspore/_extends/graph_kernel/expander.py +0 -80
  507. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  508. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  509. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  510. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  511. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  512. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  513. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  514. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  516. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  517. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  518. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  519. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  520. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  521. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  522. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  523. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  524. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  525. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  526. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  527. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  528. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  529. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  530. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  531. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  532. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  534. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  535. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  536. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  537. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  538. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  540. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  543. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  544. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  545. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  546. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  547. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  548. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  549. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  550. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  551. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  552. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  553. mindspore/dataset/datapreprocess/__init__.py +0 -20
  554. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  555. mindspore/include/api/net.h +0 -142
  556. mindspore/nn/lr_scheduler.py +0 -262
  557. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  558. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  559. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  560. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  561. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  562. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  563. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  564. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  565. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  566. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  567. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  568. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  569. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  570. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  571. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  575. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  578. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  579. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  580. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  582. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  583. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  585. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  586. mindspore/rewrite/node_visitor.py +0 -44
  587. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  588. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  589. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -42,6 +42,24 @@ from ._pyfunc_registry import add_pyfunc
 if platform.system() != "Windows":
     import fcntl
 
+KEY_ATTR = "attr"
+KEY_NAME = "name"
+INPUT_NAMES = "input_names"
+ATTR_NAMES = "attr_names"
+AUTO_DIFF = "autodiff"
+IMPLY_TYPE = "imply_type"
+FUSION_TYPE = "fusion_type"
+MS_KERNEL_FLAG = "ms_kernel_flag"
+AKG = "AKG"
+TBE = "TBE"
+CUDA = "CUDA"
+AICORE = "AiCore"
+CPU = "CPU"
+GPU = "GPU"
+ASCEND = "Ascend"
+HYBRID_TYPE = "hybrid"
+OP_NAME = "op_name"
+
 
 def _get_cache_path():
     """
@@ -146,10 +164,10 @@ class Custom(ops.PrimitiveWithInfer):
     function if needed. Then these `Custom` objects can be directly used in neural networks.
     Detailed description and introduction of user-defined operators, including correct writing of parameters,
     please refer to `Custom Operators Tutorial
-    <https://www.mindspore.cn/tutorials/experts/en/r2.1/operation/op_custom.html>`_ .
+    <https://www.mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html>`_ .
 
     .. warning::
-        This is an experimental API that is subject to change.
+        - This is an experimental API that is subject to change.
 
     .. note::
         The supported platforms are determined by the input `func_type`. The supported platforms are as follows:
@@ -162,6 +180,12 @@ class Custom(ops.PrimitiveWithInfer):
         - "julia": supports ["CPU"].
         - "aicpu": supports ["Ascend"].
 
+        If run on ge backend, use `CustomRegOp` to generate the registration information of "aicpu" and "tbe" operator,
+        use `custom_info_register` to bind the registration information to the `func` of the "tbe" operator,
+        then save the registration information of "aicpu" operator and the `func` implementation of "tbe" operator to
+        a file or separate files, keep these files in a separate directory, and set the absolute path of this directory
+        to environment variable "MS_DEV_CUSTOM_OPP_PATH" before running the network.
+
     Args:
         func (Union[function, str]):
 
@@ -446,10 +470,10 @@ class Custom(ops.PrimitiveWithInfer):
     op_path_in_cache = []  # Save paths for op functions created in the cached.
     custom_aot_warning = True  # Flag to enable warnings about custom aot path white list
 
-    def __init__(self, func, out_shape=None, out_dtype=None, func_type="hybrid", bprop=None, reg_info=None):
-        ops.PrimitiveWithInfer.__init__(self, "Custom")
+    def __init__(self, func, out_shape=None, out_dtype=None, func_type=HYBRID_TYPE, bprop=None, reg_info=None):
+        super().__init__("Custom")
 
-        self.supported_targets = ["Ascend", "GPU", "CPU"]
+        self.supported_targets = [ASCEND, GPU, CPU]
         self.supported_func_type = ["hybrid", "akg", "tbe", "aicpu", "aot", "pyfunc", "julia"]
         self.log_prefix = "For '{}', 'func_type': {}, 'func': {}".format(self.name, func_type, func)
         self.func = func
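
The note above about the ge backend refers to `CustomRegOp` and `custom_info_register` from mindspore.ops. The following is a minimal sketch, not taken from this diff, of how the two are typically combined for a "tbe" operator; the operator name "CustomAdd", the function name `add_tbe`, and the chosen dtype formats are illustrative only, and the TBE kernel body is omitted:

    from mindspore.ops import CustomRegOp, DataType, custom_info_register

    # Registration information for a hypothetical "CustomAdd" TBE kernel.
    custom_add_info = CustomRegOp("CustomAdd") \
        .input(0, "x") \
        .input(1, "y") \
        .output(0, "z") \
        .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
        .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
        .target("Ascend") \
        .get_op_info()

    @custom_info_register(custom_add_info)
    def add_tbe(x, y, z, kernel_name="add_tbe"):
        # The TBE DSL implementation would go here; omitted in this sketch.
        ...

Per the note, for the ge backend the file holding this kind of registration (and, for "aicpu" operators, the registration information on its own) would be kept in a dedicated directory whose absolute path is exported as MS_DEV_CUSTOM_OPP_PATH before the network runs.
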
@@ -466,7 +490,7 @@ class Custom(ops.PrimitiveWithInfer):
         self._update_func_info(reg_info)
         self.add_prim_attr("func_name", self.func_name)
         self.add_prim_attr("uniq_name", self.uniq_name)
-        if self.func_type == "hybrid":
+        if self.func_type == HYBRID_TYPE:
             self.add_prim_attr("func_compile_attrs", self._func_compile_attrs)
 
         self.add_prim_attr("imply_path", self.imply_path)
@@ -495,7 +519,7 @@ class Custom(ops.PrimitiveWithInfer):
         if func_type == "akg":
             self._set_akg_kernel_type()
 
-        if not self.bprop and self.func_type == "hybrid":
+        if not self.bprop and self.func_type == HYBRID_TYPE:
             self._hybrid_autodiff(func_type)
 
         self.add_prim_attr("func_type", self.func_type)
@@ -570,7 +594,7 @@ class Custom(ops.PrimitiveWithInfer):
         elif "compute" in self.func_source_str:
             self.func_type = "tvm_compute"
         else:
-            self.func_type = "hybrid"
+            self.func_type = HYBRID_TYPE
             self._hybrid_func_analyser()
 
     def _check_julia_func(self):
@@ -620,24 +644,24 @@ class Custom(ops.PrimitiveWithInfer):
                     raise TypeError(
                         "{}, the legal path for the file is {}, but the file is {}".format(
                             self.log_prefix, legal_path, file_path))
-                if not file_path.endswith("so"):
+                if file_path.endswith(("cu", "cpp", "cc")):
                     file_path = _compile_aot(file_path)
                 self.func = file_path + ":" + file_name_list[1]
 
         elif self.func_type == "julia":
             self._check_julia_func()
-        elif self.func_type == "hybrid":
-            if not hasattr(self.func, "ms_kernel_flag"):
+        elif self.func_type == HYBRID_TYPE:
+            if not hasattr(self.func, MS_KERNEL_FLAG):
                 raise TypeError("{}, 'func' must be a function decorated by kernel".format(self.log_prefix))
             self._is_ms_kernel = True
             self._func_compile_attrs = getattr(self.func, "compile_attrs", {})
         elif self.func_type == "akg":
-            if hasattr(self.func, "ms_kernel_flag"):
+            if hasattr(self.func, MS_KERNEL_FLAG):
                 logger.warning("{}. To have a better user experience, the mode hybrid is suggested "
                                "for the input function with decorator @kernel. "
                                "To enable this mode, set the 'func_type' to be \"hybrid\"".format(self.log_prefix))
         elif self.func_type == "pyfunc":
-            if hasattr(self.func, "ms_kernel_flag"):
+            if hasattr(self.func, MS_KERNEL_FLAG):
                 logger.warning("{}. Now you are using the function with decorator @kernel in the mode pyfunc. "
                                "The kernel will be executed as a native python function, which might lead to "
                                "low efficiency. To accelerate the kernel, set the 'func_type' to be \"hybrid\""
@@ -751,7 +775,7 @@ class Custom(ops.PrimitiveWithInfer):
                     continue
                 if isinstance(reg_info_item, str):
                     reg_info_item = json.loads(reg_info_item)
-                prefix = "_".join([prefix, reg_info_item.get("op_name", "")])
+                prefix = "_".join([prefix, reg_info_item.get(OP_NAME, "")])
             self.uniq_name = prefix + "_" + self.func_name
         else:
             raise TypeError("For '{}', 'func' must be of type function or str, but got {}"
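
For context on the hybrid-mode checks above: with func_type="hybrid", `func` is expected to be a function decorated by mindspore.ops.kernel, which is what attaches the ms_kernel_flag attribute being tested. A minimal sketch of that pattern, with an illustrative `add_hybrid` kernel, could look like the following; compiling it requires a backend with AKG support (GPU or Ascend):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, ops
    from mindspore.ops import kernel

    @kernel
    def add_hybrid(a, b):
        # output_tensor is supplied by the hybrid DSL that parses this function.
        c = output_tensor(a.shape, a.dtype)
        for i0 in range(a.shape[0]):
            for i1 in range(a.shape[1]):
                c[i0, i1] = a[i0, i1] + b[i0, i1]
        return c

    # With func_type="hybrid", output shape and dtype are inferred from the kernel,
    # and a backward op can be generated automatically (see _hybrid_autodiff below).
    custom_add = ops.Custom(add_hybrid, func_type="hybrid")
    x = Tensor(np.ones((2, 3)), ms.float32)
    y = Tensor(np.ones((2, 3)), ms.float32)
    out = custom_add(x, y)
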
@@ -761,23 +785,23 @@ class Custom(ops.PrimitiveWithInfer):
         """Update op attrs in reg_info."""
         output_name_list = []
         for _, item in enumerate(reg_info.get("outputs", [])):
-            if isinstance(item, dict) and item.get("name"):
-                output_name_list.append(item.get("name"))
+            if isinstance(item, dict) and item.get(KEY_NAME):
+                output_name_list.append(item.get(KEY_NAME))
         if output_name_list:
             self.add_prim_attr("output_names", output_name_list)
 
-        if isinstance(reg_info.get("op_name"), str):
-            self.add_prim_attr("reg_op_name", reg_info.get("op_name"))
+        if isinstance(reg_info.get(OP_NAME), str):
+            self.add_prim_attr("reg_op_name", reg_info.get(OP_NAME))
 
         if self.func_type == "aicpu":
-            self.uniq_name = reg_info["op_name"]
+            self.uniq_name = reg_info[OP_NAME]
             self.add_prim_attr("uniq_name", self.uniq_name)
 
         if self.func_type in ["aot", "aicpu"]:
-            if reg_info.get("attr") is not None and isinstance(reg_info["attr"], list):
-                for item in reg_info["attr"]:
+            if reg_info.get(KEY_ATTR) is not None and isinstance(reg_info[KEY_ATTR], list):
+                for item in reg_info[KEY_ATTR]:
                     if isinstance(item, dict) and item.get("value") is not None:
-                        self.add_prim_attr(item["name"], item["value"])
+                        self.add_prim_attr(item[KEY_NAME], item["value"])
 
     def _register_info(self, info):
         """Register reg_info."""
@@ -795,7 +819,7 @@ class Custom(ops.PrimitiveWithInfer):
         if isinstance(reg_info, str):
             reg_info = json.loads(reg_info)
         if self.fake_output:
-            reg_info["outputs"].append(dict({"index": 0, "name": "y", "param_type": "required"}))
+            reg_info["outputs"].append(dict({"index": 0, KEY_NAME: "y", "param_type": "required"}))
             new_dtype_format = []
             for i in reg_info["dtype_format"]:
                 new_dtype_format.append(i + (DataType.I32_Default,))
@@ -867,16 +891,16 @@ class Custom(ops.PrimitiveWithInfer):
                             "'CustomRegOp' to generate the registration information, then pass it to 'reg_info' or "
                             "use 'custom_info_register' to bind it to 'func' if 'func' is a function."
                             .format(self.log_prefix, reg_info, type(reg_info)))
-        reg_info["op_name"] = self.uniq_name
-        reg_info["imply_type"] = self._get_imply_type(reg_info, target)
-        if not isinstance(reg_info.get("fusion_type"), str) or not reg_info["fusion_type"].strip():
-            reg_info["fusion_type"] = "OPAQUE"
+        reg_info[OP_NAME] = self.uniq_name
+        reg_info[IMPLY_TYPE] = self._get_imply_type(reg_info, target)
+        if not isinstance(reg_info.get(FUSION_TYPE), str) or not reg_info[FUSION_TYPE].strip():
+            reg_info[FUSION_TYPE] = "OPAQUE"
         # Supplement necessary info for TBE if these information is missing in reg_info
-        if reg_info["imply_type"] == "TBE":
-            if reg_info.get("attr") is not None and isinstance(reg_info["attr"], list):
-                for i, item in enumerate(reg_info["attr"]):
+        if reg_info[IMPLY_TYPE] == TBE:
+            if reg_info.get(KEY_ATTR) is not None and isinstance(reg_info[KEY_ATTR], list):
+                for i, item in enumerate(reg_info[KEY_ATTR]):
                     if isinstance(item, dict) and item.get("value") is None:
-                        reg_info["attr"][i]["value"] = "all"
+                        reg_info[KEY_ATTR][i]["value"] = "all"
         reg_info["async_flag"] = reg_info.get("async_flag", False)
         reg_info["binfile"] = "%s.so" % self.func_name
         reg_info["compute_cost"] = reg_info.get("compute_cost", 10)
@@ -884,8 +908,8 @@ class Custom(ops.PrimitiveWithInfer):
         reg_info["partial_flag"] = reg_info.get("partial_flag", True)
         reg_info["needCheckSupport"] = reg_info.get("need_check_supported", False)
         # Supplement necessary info for AKG if these information is missing in reg_info
-        if reg_info["imply_type"] == "AKG":
-            target_to_processor = {"Ascend": "AiCore", "GPU": "CUDA", "CPU": "CPU"}
+        if reg_info[IMPLY_TYPE] == AKG:
+            target_to_processor = {ASCEND: AICORE, GPU: CUDA, CPU: CPU}
             reg_info["processor"] = reg_info.get("processor", target_to_processor.get(target))
         return reg_info
 
@@ -898,15 +922,15 @@ class Custom(ops.PrimitiveWithInfer):
         # Infer target from reg_info["processor"], reg_info generated from AkgGpuRegOp or AkgAscendRegOp
         # will have the processor information.
         if target not in self.supported_targets:
-            processor_to_target = {"AiCore": "Ascend", "CUDA": "GPU", "CPU": "CPU"}
+            processor_to_target = {AICORE: ASCEND, CUDA: GPU, CPU: CPU}
             target = processor_to_target.get(reg_info.get("processor"))
-        # Infer target from reg_info["imply_type"]
+        # Infer target from reg_info[IMPLY_TYPE]
         if target not in self.supported_targets:
-            imply_type_to_target = {"TBE": "Ascend", "GPU": "GPU", "CPU": "CPU"}
-            target = imply_type_to_target.get(reg_info.get("imply_type"))
+            imply_type_to_target = {TBE: ASCEND, GPU: GPU, CPU: CPU}
+            target = imply_type_to_target.get(reg_info.get(IMPLY_TYPE))
         # Infer target from func_type
         if target not in self.supported_targets:
-            func_type_to_target = {"tbe": "Ascend", "pyfunc": "CPU"}
+            func_type_to_target = {"tbe": ASCEND, "pyfunc": CPU}
             target = func_type_to_target.get(self.func_type)
         if target not in self.supported_targets:
             raise ValueError("{}, target set in registration information must be one of {}, but got {}"
@@ -915,14 +939,14 @@ class Custom(ops.PrimitiveWithInfer):
 
     def _get_imply_type(self, reg_info, target):
         """Get imply_typ information."""
-        # Get imply_type from reg_info["imply_type"]
-        if isinstance(reg_info, dict) and isinstance(reg_info.get("imply_type"), str) and \
-                reg_info["imply_type"].strip():
-            return reg_info["imply_type"]
+        # Get imply_type from reg_info[IMPLY_TYPE]
+        if isinstance(reg_info, dict) and isinstance(reg_info.get(IMPLY_TYPE), str) and \
+                reg_info[IMPLY_TYPE].strip():
+            return reg_info[IMPLY_TYPE]
         # Infer imply_type from func_type
-        func_type_to_imply_type = {"hybrid": "AKG", "akg": "AKG", "tbe": "TBE", "aicpu": "AiCPU", "pyfunc": target,
-                                   "julia": target, "aot": "BiSheng" if target == "Ascend" else target}
-        return func_type_to_imply_type.get(self.func_type, "AKG")
+        func_type_to_imply_type = {"hybrid": AKG, "akg": AKG, "tbe": TBE, "aicpu": "AiCPU", "pyfunc": target,
+                                   "julia": target, "aot": "BiSheng" if target == ASCEND else target}
+        return func_type_to_imply_type.get(self.func_type, AKG)
 
     def _save_attr(self, reg_info):
         """Save input_names and attr_names of current func."""
@@ -936,18 +960,18 @@ class Custom(ops.PrimitiveWithInfer):
             return value
 
         tensor_inputs = _get_value_list("inputs")
-        attr = _get_value_list("attr")
+        attr = _get_value_list(KEY_ATTR)
         input_names = []  # include tensor input names and attr input names
         attr_names = []
         pure_input_names = []
         for item in tensor_inputs:
-            if isinstance(item, dict) and item.get("name") is not None:
-                input_names.append(item["name"])
-                pure_input_names.append(item["name"])
+            if isinstance(item, dict) and item.get(KEY_NAME) is not None:
+                input_names.append(item[KEY_NAME])
+                pure_input_names.append(item[KEY_NAME])
         # attr is converted from inputs only when graph mode or when inputs name is also in reg info
         attr_to_input_safe = bool(input_names) or context.get_context("mode") == ms.GRAPH_MODE
         for item in attr:
-            if isinstance(item, dict) and item.get("name") is not None:
+            if isinstance(item, dict) and item.get(KEY_NAME) is not None:
                 # for custom op with function tbe, we always add attrs to inputs as we don't
                 # deal with attr value here and leave them to the backend process to fit the
                 # usual process of tbe op compiling in mindspore
@@ -956,9 +980,9 @@ class Custom(ops.PrimitiveWithInfer):
                 # add attr name to input name only when the value of attr is None in reg info
                 # as we need to get values of attrs from inputs
                 if attr_to_input_safe and (self.func_type == "tbe" or item.get("value", None) is None):
-                    input_names.append(item["name"])
-                attr_names.append(item["name"])
-        cur_attr = {"input_names": input_names, "attr_names": attr_names, "pure_input_names": pure_input_names}
+                    input_names.append(item[KEY_NAME])
+                attr_names.append(item[KEY_NAME])
+        cur_attr = {INPUT_NAMES: input_names, ATTR_NAMES: attr_names, "pure_input_names": pure_input_names}
         # If func does not have attr, save current attr.
         # Else, check if current attr is same as previous saved one.
         prev_attr_names = attr_names
@@ -967,13 +991,13 @@ class Custom(ops.PrimitiveWithInfer):
             if not isinstance(func_attr, dict):
                 setattr(self.func, "func_attr", cur_attr)
             else:
-                prev_attr_names = func_attr.get("attr_names")
+                prev_attr_names = func_attr.get(ATTR_NAMES)
         elif isinstance(self.func, str):
             func_attr = Custom.attr_dict.get(self.func)
             if not isinstance(func_attr, dict):
                 Custom.attr_dict[self.func] = cur_attr
             else:
-                prev_attr_names = func_attr.get("attr_names")
+                prev_attr_names = func_attr.get(ATTR_NAMES)
         if attr_names != prev_attr_names:
             raise ValueError("{}, attr names set in registration information must be the same as previous saved one, "
                              "but got {} vs {}".format(self.log_prefix, attr_names, prev_attr_names))
@@ -982,23 +1006,23 @@ class Custom(ops.PrimitiveWithInfer):
         """Add primitive_target to primitive's attr."""
         registered_targets = self._get_registered_targets()
         if self.func_type == "pyfunc":
-            self.set_device("CPU")
-            if registered_targets and registered_targets != ["CPU"]:
+            self.set_device(CPU)
+            if registered_targets and registered_targets != [CPU]:
                 logger.warning("{}, only supports CPU platform, but got registered target {}. "
                                "We will run it on CPU".format(self.log_prefix, registered_targets))
         elif self.func_type == "aot":
             if len(registered_targets) != 1:
                 logger.info("{}, target will be set according to context.".format(self.log_prefix))
-            elif registered_targets == ["GPU"]:
-                self.set_device("GPU")
-            elif registered_targets == ["CPU"]:
-                self.set_device("CPU")
+            elif registered_targets == [GPU]:
+                self.set_device(GPU)
+            elif registered_targets == [CPU]:
+                self.set_device(CPU)
         elif self.func_type == "julia":
-            self.set_device("CPU")
+            self.set_device(CPU)
             device_target = context.get_context('device_target')
-            if device_target == "CPU":
+            if device_target == CPU:
                 pass
-            elif device_target == "GPU" and registered_targets and registered_targets == ["CPU"]:
+            elif device_target == GPU and registered_targets and registered_targets == [CPU]:
                 logger.warning("{}, only supports CPU platform, but got registered target {}. "
                                "We will run it on CPU".format(self.log_prefix, registered_targets))
             else:
@@ -1021,15 +1045,15 @@ class Custom(ops.PrimitiveWithInfer):
         elif isinstance(self.func, str):
             func_attr = Custom.attr_dict.get(self.func)
         if isinstance(func_attr, dict):
-            _add_prim_attr("input_names")
-            _add_prim_attr("attr_names")
+            _add_prim_attr(INPUT_NAMES)
+            _add_prim_attr(ATTR_NAMES)
             _add_prim_attr("pure_input_names")
         self._add_prim_target()
         if callable(self.func) and callable(self.out_shape):
-            if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == "autodiff":
-                self.add_prim_attr("autodiff", True)
+            if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == AUTO_DIFF:
+                self.add_prim_attr(AUTO_DIFF, True)
             else:
-                self.add_prim_attr("autodiff", False)
+                self.add_prim_attr(AUTO_DIFF, False)
 
     def _hybrid_autodiff(self, input_func_type):
         """generate backward op for a custom hybrid op"""
@@ -1045,7 +1069,7 @@ class Custom(ops.PrimitiveWithInfer):
         def infer_func(*args):
             return args[:inputs_num]
 
-        setattr(infer_func, "type", "autodiff")
+        setattr(infer_func, "type", AUTO_DIFF)
         op = Custom(func=self.func, out_shape=infer_func, out_dtype=infer_func,
                     func_type=input_func_type, bprop=True)
         self.bprop = grad_func(op)
@@ -58,7 +58,7 @@ class ScalarSummary(Primitive):
     This operator will put a scalar to a summary file with protocol buffer format. It must be used with SummaryRecord
     or SummaryCollector, which specify the directory of the summary file. The summary file can
     be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
-    mindinsight/docs/en/r2.1/index.html>`_ for details.
+    mindinsight/docs/en/r2.2/index.html>`_ for details.
 
     Inputs:
         - **name** (str) - The name of the input variable, it must not be an empty string.
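
The docstring above ties ScalarSummary to SummaryRecord or SummaryCollector. A small usage sketch, not part of this diff, recording one scalar per step; the "./summary_dir" directory and the "x" tag are illustrative:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn, ops
    from mindspore.train.summary import SummaryRecord

    class SummaryNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.summary = ops.ScalarSummary()

        def construct(self, x):
            self.summary("x", x)  # tag the scalar so MindInsight can display it
            return x

    with SummaryRecord("./summary_dir") as record:
        SummaryNet()(Tensor(np.array(3.2), ms.float32))
        record.record(step=1)
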
@@ -104,6 +104,7 @@ class ScalarSummary(Primitive):
104
104
  raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
105
105
 
106
106
  self.add_prim_attr("side_effect_io", True)
107
+ self.add_prim_attr("channel_name", "ms_scalar_summary")
107
108
 
108
109
  def __call__(self, *args):
109
110
  _cache_summary_data(self.name, args[0], args[1])
@@ -114,7 +115,7 @@ class ImageSummary(PrimitiveWithInfer):
114
115
  This operator will put an image tensor to a summary file with protocol buffer format. It must be used with
115
116
  SummaryRecord or SummaryCollector, which specify the directory of the summary file. The summary file can
116
117
  be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
117
- mindinsight/docs/en/r2.1/index.html>`_ for details.
118
+ mindinsight/docs/en/r2.2/index.html>`_ for details.
118
119
 
119
120
  Inputs:
120
121
  - **name** (str) - The name of the input variable, it must not be an empty string.
@@ -153,6 +154,7 @@ class ImageSummary(PrimitiveWithInfer):
153
154
  raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
154
155
 
155
156
  self.add_prim_attr("side_effect_io", True)
157
+ self.add_prim_attr("channel_name", "ms_image_summary")
156
158
 
157
159
  def __infer__(self, name, value):
158
160
  _check_summary_param(name, value, self.__class__.__name__)
@@ -175,7 +177,7 @@ class TensorSummary(Primitive):
  This operator will put a tensor to a summary file with protocol buffer format. It must be used with SummaryRecord
  or SummaryCollector, which specify the directory of the summary file. The summary file can
  be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
- mindinsight/docs/en/r2.1/index.html>`_ for details.
+ mindinsight/docs/en/r2.2/index.html>`_ for details.

  Inputs:
  - **name** (str) - The name of the input variable.
@@ -221,6 +223,7 @@ class TensorSummary(Primitive):
  raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')

  self.add_prim_attr("side_effect_io", True)
+ self.add_prim_attr("channel_name", "ms_tensor_summary")

  def __call__(self, *args):
  _cache_summary_data(self.name, args[0], args[1])
@@ -231,7 +234,7 @@ class HistogramSummary(PrimitiveWithInfer):
  This operator will calculate the histogram of a tensor and put it to a summary file with protocol buffer format.
  It must be used with SummaryRecord or SummaryCollector, which specify the directory of the summary file.
  The summary file can be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
- mindinsight/docs/en/r2.1/index.html>`_ for details.
+ mindinsight/docs/en/r2.2/index.html>`_ for details.

  Inputs:
  - **name** (str) - The name of the input variable.
@@ -276,6 +279,7 @@ class HistogramSummary(PrimitiveWithInfer):
  raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')

  self.add_prim_attr("side_effect_io", True)
+ self.add_prim_attr("channel_name", "ms_histogram_summary")

  def __infer__(self, name, value):
  _check_summary_param(name, value, self.__class__.__name__)
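With these hunks, all four summary primitives now carry an explicit channel name (ms_scalar_summary, ms_image_summary, ms_tensor_summary, ms_histogram_summary). A small inspection sketch, assuming `Primitive.attrs` exposes the registered attribute dictionary as in current MindSpore releases:

    import mindspore.ops as ops

    for prim in (ops.ScalarSummary(), ops.ImageSummary(),
                 ops.TensorSummary(), ops.HistogramSummary()):
        # Each primitive should now report both side_effect_io and its channel_name.
        print(prim.name, prim.attrs.get("channel_name"))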
@@ -180,7 +180,7 @@ class AdjustHue(Primitive):


  class ExtractGlimpse(Primitive):
- """
+ r"""
  Extracts glimpses(usually subarea of rectangle) from the input image Tensor and return as windows.

  Note:
@@ -205,20 +205,20 @@ class ExtractGlimpse(Primitive):
  - When `noise` is ``'zero'`` , the value of `uniform_noise` must be ``'False'`` and the
  filling noise will be zero so that the result is fixed.
  - When `uniform_noise` is ``'True'`` , the value of `noise` only can be ``'uniform'`` .
- When `uniform_noise` is ``'False'`` , the value of `noise` can be ``'uniform'`` , ``'gaussian'`` and
+ When `uniform_noise` is ``'False'`` , the value of `noise` can be ``'uniform'`` , ``'gaussian'`` or
  ``'zero'`` .

  Inputs:
- - **x** (Tensor) - A 4-D float tensor of shape :math:`(batch_size, height, width, channels)`.
+ - **x** (Tensor) - A 4-D float tensor of shape :math:`(batch\_size, height, width, channels)`.
  Types allowed: float32.
  - **size** (Tensor) - A 1-D tensor of 2 elements containing the size of the glimpses to extract.
  The glimpse height must be specified first, following by the glimpse width. Types allowed: int32.
  The value of size must be greater than zero.
- - **offsets** (Tensor) - A 2-D integer tensor of shape :math:`(batch_size, 2)` containing the y, x locations
+ - **offsets** (Tensor) - A 2-D integer tensor of shape :math:`(batch\_size, 2)` containing the y, x locations
  of the center of each window. Types allowed: float32.

  Outputs:
- A 4-D tensor of shape :math:`(batch_size, glimpse_height, glimpse_width, channels)` with type: float32.
+ A 4-D tensor of shape :math:`(batch\_size, glimpse\_height, glimpse\_width, channels)` with type: float32.

  Raises:
  TypeError: If `centered` is not a bool.
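A hedged usage sketch of ExtractGlimpse built only from the shapes documented above; the constructor arguments shown (centered, normalized, uniform_noise, noise) are assumed to follow their documented meanings, and the exact defaults are not visible in this hunk:

    import numpy as np
    import mindspore as ms
    import mindspore.ops as ops

    # noise='zero' requires uniform_noise=False, per the note above.
    op = ops.ExtractGlimpse(centered=True, normalized=True, uniform_noise=False, noise="zero")
    x = ms.Tensor(np.random.rand(1, 8, 8, 3).astype(np.float32))   # (batch_size, height, width, channels)
    size = ms.Tensor([4, 4], ms.int32)                             # glimpse height first, then width
    offsets = ms.Tensor(np.zeros((1, 2)).astype(np.float32))       # (batch_size, 2) window centres
    glimpses = op(x, size, offsets)                                # expected shape (1, 4, 4, 3), float32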
@@ -277,9 +277,16 @@ class CropAndResize(Primitive):

  Args:
  method (str, optional): An optional string that specifies the sampling method for resizing.
- It can be ``"bilinear"`` , ``"nearest"`` or ``"bilinear_v2"`` . The option "bilinear" stands for standard
- ``"bilinear"`` interpolation algorithm, while ``"bilinear_v2"`` may result in better result in some cases.
- Default: ``"bilinear"`` .
+ It can be ``"bilinear"`` , ``"nearest"`` or ``"bilinear_v2"`` . Default: ``"bilinear"`` .
+
+ - ``"nearest"``: Nearest neighbor interpolation. Each output pixel is assigned the value of the
+ nearest input pixel. This method is simple and fast but can result in blocky or pixelated outputs.
+ - ``"bilinear"``: Bilinear interpolation. Each output pixel is a weighted average of the four nearest input
+ pixels, computed using bilinear interpolation. This method produces smoother results compared
+ to nearest neighbor interpolation.
+ - ``"bilinear_v2"``: The optimized variant of
+ ``"bilinear"``, it may achieve better result(higher precision and speed) in some cases.
+
  extrapolation_value (float, optional): An optional float value used extrapolation, if applicable.
  Default: ``0.0`` .

@@ -358,7 +365,6 @@ class CropAndResize(Primitive):
  self.method = method
  validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name)
  self.extrapolation_value = extrapolation_value
- self.is_ge = context.get_context("enable_ge")


  class NonMaxSuppressionV3(Primitive):
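A hedged call sketch for the CropAndResize method options documented in the two hunks above; only `method` and `extrapolation_value` appear in this diff, so the runtime input signature (image, boxes, box_index, crop_size) is an assumption based on the primitive's usual TF-style interface:

    import numpy as np
    import mindspore as ms
    import mindspore.ops as ops

    op = ops.CropAndResize(method="bilinear_v2", extrapolation_value=0.0)
    image = ms.Tensor(np.random.rand(1, 32, 32, 3).astype(np.float32))
    boxes = ms.Tensor([[0.1, 0.1, 0.9, 0.9]], ms.float32)   # normalized y1, x1, y2, x2 per box
    box_index = ms.Tensor([0], ms.int32)                    # which batch image each box refers to
    crop_size = (8, 8)                                      # assumed constant (crop_height, crop_width)
    crops = op(image, boxes, box_index, crop_size)          # expected shape (1, 8, 8, 3)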
@@ -470,7 +476,7 @@ class NonMaxSuppressionWithOverlaps(Primitive):
  Raises:
  TypeError: If the dtype of `overlaps` , `scores` `overlap_threshold` and `score_threshold`
  is not float16, float32 or float64.
- TypeError: If `overlaps` or `scores` is not Tensor
+ TypeError: If `overlaps` or `scores` is not Tensor.
  TypeError: If `max_output_size` is not Tensor or Scalar.If `max_output_size` is not int32.
  TypeError: If `overlap_threshold` is not Tensor or scalar. If its type is not float16, float32 or float64.
  TypeError: If `score_threshold` is not Tensor or scalar. If its type is not float16, float32 or float64.
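A hedged call sketch assembled only from the Raises list above; the positional order of the five inputs (overlaps, scores, max_output_size, overlap_threshold, score_threshold) is an assumption and is not confirmed by this hunk:

    import mindspore as ms
    import mindspore.ops as ops

    op = ops.NonMaxSuppressionWithOverlaps()
    overlaps = ms.Tensor([[1.0, 0.6], [0.6, 1.0]], ms.float32)  # pairwise overlap matrix, (num_boxes, num_boxes)
    scores = ms.Tensor([0.9, 0.8], ms.float32)                  # one score per box
    max_output_size = ms.Tensor(2, ms.int32)
    overlap_threshold = ms.Tensor(0.5, ms.float32)
    score_threshold = ms.Tensor(0.1, ms.float32)
    selected = op(overlaps, scores, max_output_size, overlap_threshold, score_threshold)  # indices of kept boxes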
@@ -1115,13 +1121,13 @@ class CombinedNonMaxSuppression(Primitive):
  it exceeds `max_total_size`.

  Inputs:
- - **boxes** (Tensor) - A float32 Tensor with shape :math:`(batch_size, num_boxes, q, 4)`
+ - **boxes** (Tensor) - A float32 Tensor with shape :math:`(batch\_size, num\_boxes, q, 4)`
  representing the bounding box coordinates.
  `q` indicates mapping relationship between boxes and classes.
  If `q` is 1, all classes use the same bounding box. If `q` is equal to the number of classes,
  class-specific boxes are applied.
  - **scores** (Tensor) - A 3-D Tensor of float32 type with the shape
- :math:`(batch_size, num_boxes, num_classes)`. It contains a score value for each box,
+ :math:`(batch\_size, num\_boxes, num\_classes)`. It contains a score value for each box,
  with each row of `boxes` represented by a single score.
  - **max_output_size_per_class** (Tensor) - The maximum number of boxes that can be selected for each class
  by the non-maximum suppression algorithm, represented by a scalar Tensor of type int32.
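A hedged shape sketch for the inputs documented above; only boxes, scores and max_output_size_per_class are visible in this hunk, so the remaining size/threshold inputs below are assumptions based on the usual combined-NMS interface:

    import numpy as np
    import mindspore as ms
    import mindspore.ops as ops

    batch_size, num_boxes, num_classes, q = 2, 10, 3, 1
    boxes = ms.Tensor(np.random.rand(batch_size, num_boxes, q, 4).astype(np.float32))
    scores = ms.Tensor(np.random.rand(batch_size, num_boxes, num_classes).astype(np.float32))
    max_output_size_per_class = ms.Tensor(4, ms.int32)
    max_total_size = ms.Tensor(8, ms.int32)          # assumed additional input
    iou_threshold = ms.Tensor(0.5, ms.float32)       # assumed additional input
    score_threshold = ms.Tensor(0.1, ms.float32)     # assumed additional input
    out = ops.CombinedNonMaxSuppression()(boxes, scores, max_output_size_per_class,
                                          max_total_size, iou_threshold, score_threshold)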
@@ -238,13 +238,14 @@ class LambApplyOptimizerAssign(PrimitiveWithInfer):
  @prim_attr_register
  def __init__(self):
  """Initialize LambApplyOptimizerAssign"""
+ self.var_shape = "var_shape"
  self.add_prim_attr('side_effect_mem', True)

  def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape,
  beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape):
- validator.check("var_shape", var_shape, "m_shape", m_shape, validator.EQ, self.name)
- validator.check("var_shape", var_shape, "v_shape", v_shape, validator.EQ, self.name)
- validator.check("var_shape", var_shape, "grad_shape", grad_shape, validator.EQ, self.name)
+ validator.check(self.var_shape, var_shape, "m_shape", m_shape, validator.EQ, self.name)
+ validator.check(self.var_shape, var_shape, "v_shape", v_shape, validator.EQ, self.name)
+ validator.check(self.var_shape, var_shape, "grad_shape", grad_shape, validator.EQ, self.name)
  return m_shape, v_shape, m_shape

  def infer_dtype(self, grad_dtype, v_dtype, m_dtype, var_dtype, beta1_dtype, sub1_dtype,
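The three validator.check calls above all enforce one constraint: m, v and grad must match the variable's shape. A dependency-free sketch of that constraint (purely illustrative, not MindSpore's validator API):

    def check_lamb_shapes(var_shape, m_shape, v_shape, grad_shape):
        """Raise if any of m/v/grad does not match the variable's shape."""
        for name, shape in (("m_shape", m_shape), ("v_shape", v_shape), ("grad_shape", grad_shape)):
            if shape != var_shape:
                raise ValueError(f"For LambApplyOptimizerAssign, {name} {shape} must equal var_shape {var_shape}")

    check_lamb_shapes((8, 16), (8, 16), (8, 16), (8, 16))  # passes silently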
@@ -658,3 +659,25 @@ class ScaleGrad(PrimitiveWithInfer):
  @prim_attr_register
  def __init__(self):
  """Initialize ScaleGrad"""
+
+
+ class KVCacheMgr(Primitive):
+ """
+ Update past with cur and index along sequence axis.
+
+ Inputs:
+ - **past** (Parameter) - 4-D tensor with shape: :math:`(batch_size, num_head, seq_len, hidden_size)`.
+ - **cur** (Tensor) - 4-D tensor with shape: :math:`(batch_size, num_head, 1, hidden_size)`.
+ - **index** (Tensor) - 1-D tensor with shape: :math:`(batch_size,)`.
+
+ Outputs:
+ Tensor, has the same data type and shape as original `past`.
+
+ Supported Platforms:
+ ``Ascend``
+
+ """
+
+ @prim_attr_register
+ def __init__(self):
+ self.init_prim_io_names(inputs=['past', 'cur', 'index'], outputs=['past'])
+ self.add_prim_attr('side_effect_mem', True)
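The new KVCacheMgr docstring only states that `past` is updated with `cur` at `index` along the sequence axis. A hypothetical NumPy reference for that reading (an assumption based on the docstring, not the kernel's actual implementation):

    import numpy as np

    def kv_cache_mgr_reference(past, cur, index):
        """Write one new step per batch element into the cache at its `index` position."""
        out = past.copy()                             # (batch_size, num_head, seq_len, hidden_size)
        for b in range(past.shape[0]):
            out[b, :, index[b], :] = cur[b, :, 0, :]  # cur is (batch_size, num_head, 1, hidden_size)
        return out

    past = np.zeros((2, 4, 16, 8), dtype=np.float16)
    cur = np.ones((2, 4, 1, 8), dtype=np.float16)
    updated = kv_cache_mgr_reference(past, cur, np.array([3, 7]))  # rows 3 and 7 now hold `cur`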