mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (550) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  172. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  173. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  174. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  175. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  176. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  177. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  180. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  181. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  182. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  183. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  184. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  185. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  186. mindspore/nn/__init__.py +0 -2
  187. mindspore/nn/cell.py +316 -74
  188. mindspore/nn/dynamic_lr.py +21 -21
  189. mindspore/nn/layer/activation.py +21 -28
  190. mindspore/nn/layer/basic.py +15 -13
  191. mindspore/nn/layer/channel_shuffle.py +1 -1
  192. mindspore/nn/layer/container.py +271 -9
  193. mindspore/nn/layer/conv.py +310 -207
  194. mindspore/nn/layer/dense.py +8 -5
  195. mindspore/nn/layer/embedding.py +33 -27
  196. mindspore/nn/layer/flash_attention.py +82 -41
  197. mindspore/nn/layer/image.py +8 -6
  198. mindspore/nn/layer/math.py +13 -18
  199. mindspore/nn/layer/normalization.py +107 -66
  200. mindspore/nn/layer/padding.py +1 -1
  201. mindspore/nn/layer/pooling.py +131 -109
  202. mindspore/nn/layer/rnn_cells.py +22 -17
  203. mindspore/nn/layer/rnns.py +13 -16
  204. mindspore/nn/layer/thor_layer.py +1 -1
  205. mindspore/nn/layer/transformer.py +221 -154
  206. mindspore/nn/learning_rate_schedule.py +9 -1
  207. mindspore/nn/loss/loss.py +235 -174
  208. mindspore/nn/optim/ada_grad.py +2 -1
  209. mindspore/nn/optim/adadelta.py +1 -0
  210. mindspore/nn/optim/adafactor.py +2 -1
  211. mindspore/nn/optim/adam.py +7 -4
  212. mindspore/nn/optim/adamax.py +3 -2
  213. mindspore/nn/optim/adasum.py +2 -2
  214. mindspore/nn/optim/asgd.py +2 -3
  215. mindspore/nn/optim/ftrl.py +6 -5
  216. mindspore/nn/optim/lamb.py +7 -4
  217. mindspore/nn/optim/lars.py +1 -1
  218. mindspore/nn/optim/lazyadam.py +5 -3
  219. mindspore/nn/optim/momentum.py +2 -1
  220. mindspore/nn/optim/optimizer.py +53 -4
  221. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  222. mindspore/nn/optim/rmsprop.py +4 -3
  223. mindspore/nn/optim/rprop.py +23 -12
  224. mindspore/nn/optim/sgd.py +26 -11
  225. mindspore/nn/optim/thor.py +9 -7
  226. mindspore/nn/probability/bijector/bijector.py +5 -5
  227. mindspore/nn/probability/bijector/power_transform.py +27 -27
  228. mindspore/nn/probability/bijector/softplus.py +3 -3
  229. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  230. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  231. mindspore/nn/probability/distribution/beta.py +3 -3
  232. mindspore/nn/probability/distribution/categorical.py +7 -7
  233. mindspore/nn/probability/distribution/cauchy.py +0 -1
  234. mindspore/nn/probability/distribution/distribution.py +3 -3
  235. mindspore/nn/probability/distribution/gamma.py +3 -3
  236. mindspore/nn/probability/distribution/geometric.py +4 -4
  237. mindspore/nn/probability/distribution/gumbel.py +4 -4
  238. mindspore/nn/probability/distribution/log_normal.py +2 -2
  239. mindspore/nn/probability/distribution/logistic.py +2 -2
  240. mindspore/nn/probability/distribution/poisson.py +4 -4
  241. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  242. mindspore/nn/probability/distribution/uniform.py +6 -6
  243. mindspore/nn/wrap/cell_wrapper.py +78 -34
  244. mindspore/nn/wrap/grad_reducer.py +8 -5
  245. mindspore/nn/wrap/loss_scale.py +105 -42
  246. mindspore/numpy/array_creations.py +1 -2
  247. mindspore/numpy/array_ops.py +3 -2
  248. mindspore/offline_debug/convert_async.py +2 -2
  249. mindspore/ops/_grad_experimental/__init__.py +0 -5
  250. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  251. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  252. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  253. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  256. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  257. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  259. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  260. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  261. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  262. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  263. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  264. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  265. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  266. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  267. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  268. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  269. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  270. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  271. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  272. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  273. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  274. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  275. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  276. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  277. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  278. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  279. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  280. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  281. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  282. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  283. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  284. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  285. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  286. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  287. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  288. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  289. mindspore/ops/_primitive_cache.py +1 -1
  290. mindspore/ops/_tracefunc.py +45 -13
  291. mindspore/ops/_utils/utils.py +4 -1
  292. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  293. mindspore/ops/_vmap/vmap_base.py +3 -3
  294. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  295. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  296. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  297. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  298. mindspore/ops/arg_dtype_cast.py +54 -0
  299. mindspore/ops/composite/base.py +37 -10
  300. mindspore/ops/composite/math_ops.py +5 -4
  301. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  302. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  303. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  304. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  305. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  308. mindspore/ops/deprecated.py +304 -0
  309. mindspore/ops/function/__init__.py +4 -1
  310. mindspore/ops/function/array_func.py +167 -189
  311. mindspore/ops/function/clip_func.py +81 -13
  312. mindspore/ops/function/debug_func.py +1 -1
  313. mindspore/ops/function/grad/grad_func.py +18 -8
  314. mindspore/ops/function/image_func.py +10 -4
  315. mindspore/ops/function/linalg_func.py +5 -5
  316. mindspore/ops/function/math_func.py +575 -386
  317. mindspore/ops/function/nn_func.py +470 -251
  318. mindspore/ops/function/random_func.py +86 -56
  319. mindspore/ops/function/sparse_func.py +1 -1
  320. mindspore/ops/function/sparse_unary_func.py +14 -12
  321. mindspore/ops/function/vmap_func.py +6 -5
  322. mindspore/ops/functional.py +15 -10
  323. mindspore/ops/op_info_register.py +235 -19
  324. mindspore/ops/operations/__init__.py +25 -17
  325. mindspore/ops/operations/_grad_ops.py +52 -7
  326. mindspore/ops/operations/_inner_ops.py +213 -12
  327. mindspore/ops/operations/_quant_ops.py +4 -8
  328. mindspore/ops/operations/_sequence_ops.py +42 -0
  329. mindspore/ops/operations/array_ops.py +64 -280
  330. mindspore/ops/operations/comm_ops.py +105 -57
  331. mindspore/ops/operations/custom_ops.py +10 -3
  332. mindspore/ops/operations/debug_ops.py +8 -4
  333. mindspore/ops/operations/image_ops.py +18 -12
  334. mindspore/ops/operations/math_ops.py +185 -138
  335. mindspore/ops/operations/nn_ops.py +716 -492
  336. mindspore/ops/operations/other_ops.py +0 -22
  337. mindspore/ops/operations/random_ops.py +53 -111
  338. mindspore/ops/operations/sparse_ops.py +3 -1
  339. mindspore/ops/primitive.py +24 -18
  340. mindspore/parallel/_auto_parallel_context.py +68 -8
  341. mindspore/parallel/_cost_model_context.py +2 -2
  342. mindspore/parallel/_offload_context.py +17 -3
  343. mindspore/parallel/_parallel_serialization.py +2 -2
  344. mindspore/parallel/_ps_context.py +12 -0
  345. mindspore/parallel/_tensor.py +14 -12
  346. mindspore/parallel/_transformer/layers.py +5 -3
  347. mindspore/parallel/_transformer/loss.py +1 -0
  348. mindspore/parallel/_transformer/moe.py +2 -2
  349. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  350. mindspore/parallel/_transformer/transformer.py +23 -3
  351. mindspore/parallel/_utils.py +11 -7
  352. mindspore/parallel/algo_parameter_config.py +85 -5
  353. mindspore/parallel/checkpoint_transform.py +6 -10
  354. mindspore/parallel/shard.py +4 -4
  355. mindspore/profiler/common/struct_type.py +3 -3
  356. mindspore/profiler/common/util.py +3 -2
  357. mindspore/profiler/envprofiling.py +1 -1
  358. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  359. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  360. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  361. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  362. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  363. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  364. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  365. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  366. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  367. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  368. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  369. mindspore/profiler/parser/flops_parser.py +15 -11
  370. mindspore/profiler/parser/framework_parser.py +37 -21
  371. mindspore/profiler/parser/hccl_parser.py +16 -12
  372. mindspore/profiler/parser/integrator.py +22 -11
  373. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  374. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  375. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  376. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  377. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  378. mindspore/profiler/parser/optime_parser.py +1 -1
  379. mindspore/profiler/parser/profiler_info.py +2 -2
  380. mindspore/profiler/parser/step_trace_parser.py +11 -14
  381. mindspore/profiler/profiling.py +139 -71
  382. mindspore/rewrite/api/node.py +102 -19
  383. mindspore/rewrite/api/node_type.py +5 -1
  384. mindspore/rewrite/api/scoped_value.py +9 -17
  385. mindspore/rewrite/api/symbol_tree.py +131 -47
  386. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  387. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  388. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  389. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  390. mindspore/rewrite/common/rewrite_elog.py +5 -1
  391. mindspore/rewrite/namer.py +33 -24
  392. mindspore/rewrite/namespace.py +14 -5
  393. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  394. mindspore/rewrite/node/call_function.py +79 -0
  395. mindspore/rewrite/node/cell_container.py +135 -0
  396. mindspore/rewrite/node/control_flow.py +88 -0
  397. mindspore/rewrite/{node.py → node/node.py} +273 -234
  398. mindspore/rewrite/node/node_manager.py +254 -0
  399. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  400. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  401. mindspore/rewrite/parsers/assign_parser.py +216 -221
  402. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  403. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  404. mindspore/rewrite/parsers/constant_parser.py +9 -6
  405. mindspore/rewrite/parsers/container_parser.py +9 -7
  406. mindspore/rewrite/parsers/for_parser.py +36 -15
  407. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  408. mindspore/rewrite/parsers/if_parser.py +28 -24
  409. mindspore/rewrite/parsers/module_parser.py +196 -25
  410. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  411. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  412. mindspore/rewrite/parsers/return_parser.py +6 -6
  413. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  414. mindspore/rewrite/sparsify/utils.py +1 -1
  415. mindspore/rewrite/symbol_tree.py +525 -577
  416. mindspore/rewrite/symbol_tree_builder.py +9 -193
  417. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  418. mindspore/run_check/_check_version.py +2 -2
  419. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  420. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  421. mindspore/scipy/linalg.py +1 -1
  422. mindspore/scipy/optimize/minimize.py +7 -3
  423. mindspore/train/_utils.py +7 -3
  424. mindspore/train/amp.py +323 -123
  425. mindspore/train/anf_ir_pb2.py +14 -2
  426. mindspore/train/callback/_backup_and_restore.py +2 -12
  427. mindspore/train/callback/_callback.py +29 -4
  428. mindspore/train/callback/_checkpoint.py +23 -8
  429. mindspore/train/callback/_early_stop.py +2 -2
  430. mindspore/train/callback/_landscape.py +4 -4
  431. mindspore/train/callback/_loss_monitor.py +2 -2
  432. mindspore/train/callback/_on_request_exit.py +2 -2
  433. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  434. mindspore/train/callback/_summary_collector.py +14 -7
  435. mindspore/train/callback/_time_monitor.py +58 -5
  436. mindspore/train/data_sink.py +5 -11
  437. mindspore/train/dataset_helper.py +83 -57
  438. mindspore/train/loss_scale_manager.py +2 -2
  439. mindspore/train/metrics/__init__.py +3 -3
  440. mindspore/train/metrics/cosine_similarity.py +1 -1
  441. mindspore/train/metrics/hausdorff_distance.py +3 -2
  442. mindspore/train/metrics/mean_surface_distance.py +3 -2
  443. mindspore/train/metrics/metric.py +39 -19
  444. mindspore/train/metrics/roc.py +2 -2
  445. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  446. mindspore/train/mind_ir_pb2.py +85 -36
  447. mindspore/train/model.py +185 -45
  448. mindspore/train/serialization.py +390 -150
  449. mindspore/train/summary/_writer_pool.py +3 -2
  450. mindspore/train/summary/summary_record.py +14 -10
  451. mindspore/train/train_thor/convert_utils.py +3 -3
  452. mindspore/train/train_thor/dataset_helper.py +1 -1
  453. mindspore/version.py +1 -1
  454. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  455. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
  456. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  457. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  458. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  459. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  460. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  461. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  462. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  463. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  464. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  465. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  466. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  467. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  468. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  469. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  470. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  471. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  472. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  473. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  474. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  475. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  476. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  477. mindspore/_extends/graph_kernel/expander.py +0 -80
  478. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  479. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  480. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  481. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  482. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  483. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  484. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  485. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  486. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  487. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  488. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  489. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  490. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  491. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  492. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  493. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  494. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  495. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  496. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  497. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  498. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  499. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  500. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  501. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  502. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  503. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  504. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  505. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  506. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  507. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  509. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  510. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  511. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  512. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  514. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  515. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  516. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  517. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  518. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  519. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  520. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  521. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  522. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  523. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  524. mindspore/dataset/datapreprocess/__init__.py +0 -20
  525. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  526. mindspore/include/api/net.h +0 -142
  527. mindspore/nn/lr_scheduler.py +0 -262
  528. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  529. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  530. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  531. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  532. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  533. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  537. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  541. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  542. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  543. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  546. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  547. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  548. mindspore/rewrite/node_visitor.py +0 -44
  549. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  550. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -19,12 +19,15 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
19
19
  random_choice_with_mask_op_info = AiCPURegOp("RandomChoiceWithMask") \
20
20
  .fusion_type("OPAQUE") \
21
21
  .input(0, "x", "required") \
22
+ .input(1, "counts", "required") \
23
+ .input(2, "states", "required") \
22
24
  .output(0, "y", "required") \
23
25
  .output(1, "mask", "required") \
24
26
  .attr("count", "int") \
25
27
  .attr("seed", "int") \
26
28
  .attr("seed2", "int") \
27
- .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default) \
29
+ .dtype_format(DataType.BOOL_Default, DataType.U64_Default, DataType.U64_Default,
30
+ DataType.I32_Default, DataType.BOOL_Default) \
28
31
  .get_op_info()
29
32
 
30
33
  @op_info_register(random_choice_with_mask_op_info)
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
1
+ # Copyright 2022-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -20,60 +20,111 @@ random_poisson_op_info = AiCPURegOp("RandomPoisson") \
20
20
  .fusion_type("OPAQUE") \
21
21
  .input(0, "shape", "required") \
22
22
  .input(1, "rate", "required") \
23
+ .input(2, "counts", "required") \
24
+ .input(3, "states", "required") \
23
25
  .output(0, "output", "required") \
24
26
  .attr("seed", "int") \
25
27
  .attr("seed2", "int") \
26
- .attr("dtype", "Type") \
27
- .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
28
- .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F32_Default) \
29
- .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F64_Default) \
30
- .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.I32_Default) \
31
- .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.I64_Default) \
32
- .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F16_Default) \
33
- .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
34
- .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F64_Default) \
35
- .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default) \
36
- .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I64_Default) \
37
- .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.F16_Default) \
38
- .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.F32_Default) \
39
- .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \
40
- .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.I32_Default) \
41
- .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.I64_Default) \
42
- .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F16_Default) \
43
- .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \
44
- .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F64_Default) \
45
- .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
46
- .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
47
- .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.F16_Default) \
48
- .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.F32_Default) \
49
- .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.F64_Default) \
50
- .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
51
- .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
52
- .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
53
- .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.F32_Default) \
54
- .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.F64_Default) \
55
- .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.I32_Default) \
56
- .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.I64_Default) \
57
- .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.F16_Default) \
58
- .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
59
- .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.F64_Default) \
60
- .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I32_Default) \
61
- .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default) \
62
- .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.F16_Default) \
63
- .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.F32_Default) \
64
- .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \
65
- .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I32_Default) \
66
- .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I64_Default) \
67
- .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.F16_Default) \
68
- .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.F32_Default) \
69
- .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.F64_Default) \
70
- .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
71
- .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
72
- .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F16_Default) \
73
- .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
74
- .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F64_Default) \
75
- .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
76
- .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
28
+ .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.U64_Default,
29
+ DataType.U64_Default, DataType.F16_Default) \
30
+ .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.U64_Default,
31
+ DataType.U64_Default, DataType.F32_Default) \
32
+ .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.U64_Default,
33
+ DataType.U64_Default, DataType.F64_Default) \
34
+ .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.U64_Default,
35
+ DataType.U64_Default, DataType.I32_Default) \
36
+ .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.U64_Default,
37
+ DataType.U64_Default, DataType.I64_Default) \
38
+ .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.U64_Default,
39
+ DataType.U64_Default, DataType.F16_Default) \
40
+ .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.U64_Default,
41
+ DataType.U64_Default, DataType.F32_Default) \
42
+ .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.U64_Default,
43
+ DataType.U64_Default, DataType.F64_Default) \
44
+ .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.U64_Default,
45
+ DataType.U64_Default, DataType.I32_Default) \
46
+ .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.U64_Default,
47
+ DataType.U64_Default, DataType.I64_Default) \
48
+ .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.U64_Default,
49
+ DataType.U64_Default, DataType.F16_Default) \
50
+ .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.U64_Default,
51
+ DataType.U64_Default, DataType.F32_Default) \
52
+ .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.U64_Default,
53
+ DataType.U64_Default, DataType.F64_Default) \
54
+ .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.U64_Default,
55
+ DataType.U64_Default, DataType.I32_Default) \
56
+ .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.U64_Default,
57
+ DataType.U64_Default, DataType.I64_Default) \
58
+ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
59
+ DataType.U64_Default, DataType.F16_Default) \
60
+ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
61
+ DataType.U64_Default, DataType.F32_Default) \
62
+ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
63
+ DataType.U64_Default, DataType.F64_Default) \
64
+ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
65
+ DataType.U64_Default, DataType.I32_Default) \
66
+ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
67
+ DataType.U64_Default, DataType.I64_Default) \
68
+ .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U64_Default,
69
+ DataType.U64_Default, DataType.F16_Default) \
70
+ .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U64_Default,
71
+ DataType.U64_Default, DataType.F32_Default) \
72
+ .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U64_Default,
73
+ DataType.U64_Default, DataType.F64_Default) \
74
+ .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U64_Default,
75
+ DataType.U64_Default, DataType.I32_Default) \
76
+ .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U64_Default,
77
+ DataType.U64_Default, DataType.I64_Default) \
78
+ .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.U64_Default,
79
+ DataType.U64_Default, DataType.F16_Default) \
80
+ .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.U64_Default,
81
+ DataType.U64_Default, DataType.F32_Default) \
82
+ .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.U64_Default,
83
+ DataType.U64_Default, DataType.F64_Default) \
84
+ .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.U64_Default,
85
+ DataType.U64_Default, DataType.I32_Default) \
86
+ .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.U64_Default,
87
+ DataType.U64_Default, DataType.I64_Default) \
88
+ .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.U64_Default,
89
+ DataType.U64_Default, DataType.F16_Default) \
90
+ .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.U64_Default,
91
+ DataType.U64_Default, DataType.F32_Default) \
92
+ .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.U64_Default,
93
+ DataType.U64_Default, DataType.F64_Default) \
94
+ .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.U64_Default,
95
+ DataType.U64_Default, DataType.I32_Default) \
96
+ .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.U64_Default,
97
+ DataType.U64_Default, DataType.I64_Default) \
98
+ .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.U64_Default,
99
+ DataType.U64_Default, DataType.F16_Default) \
100
+ .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.U64_Default,
101
+ DataType.U64_Default, DataType.F32_Default) \
102
+ .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.U64_Default,
103
+ DataType.U64_Default, DataType.F64_Default) \
104
+ .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.U64_Default,
105
+ DataType.U64_Default, DataType.I32_Default) \
106
+ .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.U64_Default,
107
+ DataType.U64_Default, DataType.I64_Default) \
108
+ .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.U64_Default,
109
+ DataType.U64_Default, DataType.F16_Default) \
110
+ .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.U64_Default,
111
+ DataType.U64_Default, DataType.F32_Default) \
112
+ .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.U64_Default,
113
+ DataType.U64_Default, DataType.F64_Default) \
114
+ .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.U64_Default,
115
+ DataType.U64_Default, DataType.I32_Default) \
116
+ .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.U64_Default,
117
+ DataType.U64_Default, DataType.I64_Default) \
118
+ .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
119
+ DataType.U64_Default, DataType.F16_Default) \
120
+ .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
121
+ DataType.U64_Default, DataType.F32_Default) \
122
+ .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
123
+ DataType.U64_Default, DataType.F64_Default) \
124
+ .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
125
+ DataType.U64_Default, DataType.I32_Default) \
126
+ .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
127
+ DataType.U64_Default, DataType.I64_Default) \
77
128
  .get_op_info()
78
129
 
79
130
 
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -19,23 +19,25 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
19
19
  randomshuffle_op_info = AiCPURegOp("RandomShuffle") \
20
20
  .fusion_type("OPAQUE") \
21
21
  .input(0, "x", "required") \
22
+ .input(1, "counts", "required") \
23
+ .input(2, "states", "required") \
22
24
  .output(0, "y", "required") \
23
25
  .attr("seed", "int") \
24
26
  .attr("seed2", "int") \
25
- .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
26
- .dtype_format(DataType.I8_Default, DataType.I8_Default) \
27
- .dtype_format(DataType.I16_Default, DataType.I16_Default) \
28
- .dtype_format(DataType.I32_Default, DataType.I32_Default) \
29
- .dtype_format(DataType.I64_Default, DataType.I64_Default) \
30
- .dtype_format(DataType.U8_Default, DataType.U8_Default) \
31
- .dtype_format(DataType.U16_Default, DataType.U16_Default) \
32
- .dtype_format(DataType.U32_Default, DataType.U32_Default) \
33
- .dtype_format(DataType.U64_Default, DataType.U64_Default) \
34
- .dtype_format(DataType.F16_Default, DataType.F16_Default) \
35
- .dtype_format(DataType.F32_Default, DataType.F32_Default) \
36
- .dtype_format(DataType.F64_Default, DataType.F64_Default) \
37
- .dtype_format(DataType.C64_Default, DataType.C64_Default) \
38
- .dtype_format(DataType.C128_Default, DataType.C128_Default) \
27
+ .dtype_format(DataType.BOOL_Default, DataType.U64_Default, DataType.U64_Default, DataType.BOOL_Default) \
28
+ .dtype_format(DataType.I8_Default, DataType.U64_Default, DataType.U64_Default, DataType.I8_Default) \
29
+ .dtype_format(DataType.I16_Default, DataType.U64_Default, DataType.U64_Default, DataType.I16_Default) \
30
+ .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default, DataType.I32_Default) \
31
+ .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default, DataType.I64_Default) \
32
+ .dtype_format(DataType.U8_Default, DataType.U64_Default, DataType.U64_Default, DataType.U8_Default) \
33
+ .dtype_format(DataType.U16_Default, DataType.U64_Default, DataType.U64_Default, DataType.U16_Default) \
34
+ .dtype_format(DataType.U32_Default, DataType.U64_Default, DataType.U64_Default, DataType.U32_Default) \
35
+ .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
36
+ .dtype_format(DataType.F16_Default, DataType.U64_Default, DataType.U64_Default, DataType.F16_Default) \
37
+ .dtype_format(DataType.F32_Default, DataType.U64_Default, DataType.U64_Default, DataType.F32_Default) \
38
+ .dtype_format(DataType.F64_Default, DataType.U64_Default, DataType.U64_Default, DataType.F64_Default) \
39
+ .dtype_format(DataType.C64_Default, DataType.U64_Default, DataType.U64_Default, DataType.C64_Default) \
40
+ .dtype_format(DataType.C128_Default, DataType.U64_Default, DataType.U64_Default, DataType.C128_Default) \
39
41
  .get_op_info()
40
42
 
41
43
 
@@ -16,7 +16,7 @@
16
16
  """SparseAddmm op"""
17
17
  from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
18
18
 
19
- sparseaddmm_op_info = AiCPURegOp("SparseAddmm") \
19
+ sparse_addmm_op_info = AiCPURegOp("SparseAddmm") \
20
20
  .fusion_type("OPAQUE") \
21
21
  .input(0, "x1_indices", "required") \
22
22
  .input(1, "x1_values", "required") \
@@ -81,7 +81,7 @@ sparseaddmm_op_info = AiCPURegOp("SparseAddmm") \
81
81
  .get_op_info()
82
82
 
83
83
 
84
- @op_info_register(sparseaddmm_op_info)
84
+ @op_info_register(sparse_addmm_op_info)
85
85
  def _sparse_addmm_aicpu():
86
86
  """SparseAddmm AiCPU register"""
87
87
  return
@@ -13,10 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
 
16
- """sparsesparsemaximum op"""
16
+ """sparse_sparse_maximum op"""
17
17
  from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
18
18
 
19
- sparsesparsemaximum_op_info = AiCPURegOp("SparseSparseMaximum") \
19
+ sparse_sparse_maximum_op_info = AiCPURegOp("SparseSparseMaximum") \
20
20
  .fusion_type("OPAQUE") \
21
21
  .input(0, "x1_indices", "required") \
22
22
  .input(1, "x1_values", "required") \
@@ -47,7 +47,7 @@ sparsesparsemaximum_op_info = AiCPURegOp("SparseSparseMaximum") \
47
47
  .get_op_info()
48
48
 
49
49
 
50
- @op_info_register(sparsesparsemaximum_op_info)
51
- def _sparsesparsemaximum_aicpu():
50
+ @op_info_register(sparse_sparse_maximum_op_info)
51
+ def _sparse_sparse_maximum_aicpu():
52
52
  """SparseSparseMaximum AiCPU register"""
53
53
  return
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -19,13 +19,13 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
19
19
  laplace_op_info = AiCPURegOp("StandardLaplace") \
20
20
  .fusion_type("OPAQUE") \
21
21
  .input(0, "shape", "required") \
22
- .input(1, "seed", "required") \
23
- .input(2, "seed2", "required") \
22
+ .input(1, "counts", "required") \
23
+ .input(2, "states", "required") \
24
24
  .output(0, "output", "required") \
25
25
  .attr("seed", "int") \
26
26
  .attr("seed2", "int") \
27
- .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
28
- .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
27
+ .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default, DataType.F32_Default) \
28
+ .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default, DataType.F32_Default) \
29
29
  .get_op_info()
30
30
 
31
31
  @op_info_register(laplace_op_info)
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -19,13 +19,13 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
19
19
  normal_op_info = AiCPURegOp("StandardNormal") \
20
20
  .fusion_type("OPAQUE") \
21
21
  .input(0, "shape", "required") \
22
- .input(1, "seed", "required") \
23
- .input(2, "seed2", "required") \
22
+ .input(1, "counts", "required") \
23
+ .input(2, "states", "required") \
24
24
  .output(0, "output", "required") \
25
25
  .attr("seed", "int") \
26
26
  .attr("seed2", "int") \
27
- .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
28
- .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
27
+ .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default, DataType.F32_Default) \
28
+ .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default, DataType.F32_Default) \
29
29
  .get_op_info()
30
30
 
31
31
  @op_info_register(normal_op_info)
@@ -1,4 +1,4 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
1
+ # Copyright 2021-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -19,15 +19,17 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
19
19
  truncated_normal_op_info = AiCPURegOp("TruncatedNormal")\
20
20
  .fusion_type("OPAQUE")\
21
21
  .input(0, "shape", "required")\
22
+ .input(1, "counts", "required") \
23
+ .input(2, "states", "required") \
22
24
  .output(0, "output", "required")\
23
25
  .attr("seed", "int")\
24
26
  .attr("seed2", "int")\
25
- .dtype_format(DataType.I32_Default, DataType.F16_Default)\
26
- .dtype_format(DataType.I32_Default, DataType.F32_Default)\
27
- .dtype_format(DataType.I32_Default, DataType.F64_Default)\
28
- .dtype_format(DataType.I64_Default, DataType.F16_Default)\
29
- .dtype_format(DataType.I64_Default, DataType.F32_Default)\
30
- .dtype_format(DataType.I64_Default, DataType.F64_Default)\
27
+ .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default, DataType.F16_Default)\
28
+ .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default, DataType.F32_Default)\
29
+ .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default, DataType.F64_Default)\
30
+ .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default, DataType.F16_Default)\
31
+ .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default, DataType.F32_Default)\
32
+ .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default, DataType.F64_Default)\
31
33
  .get_op_info()
32
34
 
33
35
 
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
1
+ # Copyright 2022-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -18,13 +18,15 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
18
18
  uniform_op_info = AiCPURegOp("Uniform") \
19
19
  .fusion_type("OPAQUE") \
20
20
  .input(0, "x", "required") \
21
+ .input(1, "counts", "required") \
22
+ .input(2, "states", "required") \
21
23
  .output(0, "y", "required") \
22
24
  .attr("from", "float") \
23
25
  .attr("to", "float") \
24
26
  .attr("seed", "int") \
25
27
  .attr("offset", "int") \
26
- .dtype_format(DataType.F32_Default, DataType.F32_Default) \
27
- .dtype_format(DataType.F64_Default, DataType.F64_Default) \
28
+ .dtype_format(DataType.F32_Default, DataType.U64_Default, DataType.U64_Default, DataType.F32_Default) \
29
+ .dtype_format(DataType.F64_Default, DataType.U64_Default, DataType.U64_Default, DataType.F64_Default) \
28
30
  .get_op_info()
29
31
 
30
32
 
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -18,16 +18,20 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
18
18
  uniform_candidate_sampler_op_info = AiCPURegOp("UniformCandidateSampler") \
19
19
  .fusion_type("OPAQUE") \
20
20
  .input(0, "true_classes", "required") \
21
+ .input(1, "counts", "required") \
22
+ .input(2, "states", "required") \
21
23
  .output(0, "sampled_candidates", "required") \
22
24
  .output(1, "true_expected_count", "required") \
23
- .output(2, "true_expected_count", "required") \
25
+ .output(2, "sampled_expected_count", "required") \
24
26
  .attr("num_true", "int") \
25
27
  .attr("num_sampled", "int") \
26
28
  .attr("unique", "bool") \
27
29
  .attr("range_max", "int") \
28
30
  .attr("seed", "int") \
29
- .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
30
- .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
31
+ .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default, DataType.I64_Default,
32
+ DataType.F32_Default, DataType.F32_Default) \
33
+ .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default, DataType.I32_Default,
34
+ DataType.F32_Default, DataType.F32_Default) \
31
35
  .get_op_info()
32
36
 
33
37
 
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -21,13 +21,13 @@ uniform_int_op_info = AiCPURegOp("UniformInt") \
21
21
  .input(0, "shape", "required") \
22
22
  .input(1, "a", "required") \
23
23
  .input(2, "b", "required") \
24
- .input(3, "seed", "required") \
25
- .input(4, "seed2", "required") \
24
+ .input(3, "counts", "required") \
25
+ .input(4, "states", "required") \
26
26
  .output(0, "output", "required") \
27
27
  .attr("seed", "int") \
28
28
  .attr("seed2", "int") \
29
- .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default,
30
- DataType.I64_Default, DataType.I32_Default) \
29
+ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
30
+ DataType.U64_Default, DataType.I32_Default) \
31
31
  .get_op_info()
32
32
 
33
33
  @op_info_register(uniform_int_op_info)
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -19,12 +19,12 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
19
19
  uniform_real_op_info = AiCPURegOp("UniformReal") \
20
20
  .fusion_type("OPAQUE") \
21
21
  .input(0, "shape", "required") \
22
- .input(1, "seed", "required") \
23
- .input(2, "seed2", "required") \
22
+ .input(1, "counts", "required") \
23
+ .input(2, "states", "required") \
24
24
  .output(0, "output", "required") \
25
25
  .attr("seed", "int") \
26
26
  .attr("seed2", "int") \
27
- .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
27
+ .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default, DataType.F32_Default) \
28
28
  .get_op_info()
29
29
 
30
30
  @op_info_register(uniform_real_op_info)
@@ -16,11 +16,11 @@
16
16
  """tbe ops"""
17
17
  from .broadcast_to import _broadcast_to_tbe # The name is occupied
18
18
  from .broadcast_to_ds import _broadcast_to_ds_tbe # The name is occupied
19
- from .batch_to_space import _batch_to_space_tbe # attr type is listIntnot listListInt
20
- from .batch_to_space_nd import _batch_to_space_nd_tbe # attr type is listIntnot listListInt
19
+ from .batch_to_space import _batch_to_space_tbe # attr type is listInt, not listListInt
20
+ from .batch_to_space_nd import _batch_to_space_nd_tbe # attr type is listInt, not listListInt
21
21
  from .batch_to_space_nd_v2 import _batch_to_space_nd_v2_tbe # The name is occupied
22
- from .space_to_batch import _space_to_batch_tbe # attr type is listIntnot listListInt
23
- from .space_to_batch_nd import _space_to_batch_nd_tbe # attr type is listIntnot listListInt
22
+ from .space_to_batch import _space_to_batch_tbe # attr type is listInt, not listListInt
23
+ from .space_to_batch_nd import _space_to_batch_nd_tbe # attr type is listInt, not listListInt
24
24
  from .dynamic_gru_v2 import _dynamic_gru_v2_tbe # input4 is None, GE will change to hidden op by pass
25
25
  from .dynamic_rnn import _dynamic_rnn_tbe # input4 is None, GE will change to hidden op by pass
26
26
  from .kl_div_loss_grad import _kl_div_loss_grad_tbe # Accuracy issues
@@ -29,10 +29,14 @@ inplace_index_add_op_info = TBERegOp("InplaceIndexAdd") \
29
29
  .input(0, "input_x", False, "required", "all") \
30
30
  .input(1, "indices", False, "required", "all") \
31
31
  .input(2, "input_y", False, "required", "all") \
32
+ .input(3, "alpha", False, "optional", "all") \
32
33
  .output(0, "input_x", False, "required", "all") \
33
- .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
34
- .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
35
- .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
34
+ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default,
35
+ DataType.F16_Default) \
36
+ .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default,
37
+ DataType.F32_Default) \
38
+ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
39
+ DataType.I32_Default) \
36
40
  .get_op_info()
37
41
 
38
42
 
@@ -28,6 +28,8 @@ trans_data_op_info = TBERegOp("TransData") \
28
28
  "DefaultFormat, NC1HWC0, FRACTAL_Z, FRACTAL_NZ, HWCN, C1HWNCoC0, NDHWC, NHWC") \
29
29
  .attr("dst_format", "required", "str",
30
30
  "DefaultFormat, NC1HWC0, FRACTAL_Z, FRACTAL_NZ, HWCN, C1HWNCoC0, NDHWC, NHWC") \
31
+ .attr("src_subformat", "optional", "int", "all", "1") \
32
+ .attr("dst_subformat", "optional", "int", "all", "1") \
31
33
  .attr("groups", "optional", "int", "all", "1") \
32
34
  .input(0, "src", False, "required", "all") \
33
35
  .output(0, "dst", False, "required", "all") \
@@ -85,6 +85,6 @@ def _get_cache_prim(cls: Primitive) -> Primitive:
85
85
  _PRIM_CACHE[key] = prim
86
86
  return _PRIM_CACHE.get(key)
87
87
 
88
- if _is_need_compile(_temp_func):
88
+ if _is_need_compile(_temp_func): # @jit.cond: True
89
89
  return _new_prim_for_graph
90
90
  return _get_cache_prim_for_pynative
@@ -17,12 +17,13 @@ import functools
17
17
  import types
18
18
  import textwrap
19
19
  import inspect
20
+ import os
20
21
  from mindspore.common.tensor import Tensor
21
22
  from mindspore.ops.primitive import _RunOpHook, Primitive
22
23
  from mindspore._c_expression import PackExpander, PackNode
23
24
  from mindspore.common._stub_tensor import StubTensor
24
25
  from mindspore.common._register_for_tensor import tensor_operator_registry
25
- from mindspore.common.api import _handle_func_args
26
+ from mindspore.common.api import _handle_func_args, _pynative_executor
26
27
 
27
28
 
28
29
  class _PackTensor(StubTensor):
@@ -64,6 +65,7 @@ class PackFunc(Primitive):
64
65
  """pack function with lazy expander"""
65
66
 
66
67
  expander = PackExpander.get_instance()
68
+ current = None
67
69
 
68
70
  def __init__(self, fun, unique_key, cell_obj, is_pynative_mode=False):
69
71
  super(PackFunc, self).__init__(self.__class__.__name__)
@@ -79,19 +81,29 @@ class PackFunc(Primitive):
79
81
  args = (self.cell_obj, *args)
80
82
  return self.func(*args, **kwargs)
81
83
  self.kwargs = kwargs
82
- return super().__call__(*args)
84
+ output = super().__call__(*args)
85
+ if self.is_pynative_mode and self.grad_attach_num > 0:
86
+ output_num = len(output) - self.grad_attach_num
87
+ if output_num == 1:
88
+ return output[0]
89
+ return output[:output_num]
90
+ return output
83
91
 
84
92
  def __expand__(self, args):
93
+ old = PackFunc.current
94
+ PackFunc.current = self
85
95
  if self.cell_obj:
86
96
  args = (self.cell_obj, *args)
87
97
  with _SetMixedPrecision(self.cell_obj):
88
98
  ret = self._run_op(args)
89
- return ret
90
- return self._run_op(args)
99
+ else:
100
+ ret = self._run_op(args)
101
+ PackFunc.current = old
102
+ return ret
91
103
 
92
104
  @staticmethod
93
105
  def is_tracing():
94
- return _RunOpHook.current and _RunOpHook.current.hook is PackFunc._trace_run_op
106
+ return PackFunc.current is not None
95
107
 
96
108
  @staticmethod
97
109
  def _trace_run_op(obj, args):
@@ -197,13 +209,33 @@ def trace(fn):
197
209
 
198
210
  @functools.wraps(fn)
199
211
  def _trace_wrap(*args, **kwargs):
200
- args, kwargs = _handle_func_args(fn, *args, **kwargs)
201
- obj = None
202
-
203
- if args and not isinstance(args[0], Tensor) and hasattr(args[0], fn.__name__):
204
- obj, args = args[0], args[1:]
205
- key = f"{id(obj)}_{id(fn)}"
206
-
207
- return PackFunc(fn, key, obj, True)(*args, **kwargs)
212
+ pynative_grad_flag = _pynative_executor.grad_flag()
213
+ grad_flag_expr = "1" if pynative_grad_flag else "0"
214
+ if _trace_wrap.is_method is None:
215
+ if args and not isinstance(args[0], Tensor) and hasattr(args[0], fn.__name__):
216
+ _trace_wrap.is_method = False
217
+ else:
218
+ _trace_wrap.is_method = True
219
+ if _trace_wrap.is_method:
220
+ # Similar processing has been done in the __call__ of Cell,
221
+ # so only when obj is None, there is need to do `_handle_func_args`.
222
+ args, kwargs = _handle_func_args(fn, *args, **kwargs)
223
+ pack_func_name = "pack" + grad_flag_expr
224
+ pack_func = getattr(fn, pack_func_name, None)
225
+ if pack_func is None:
226
+ pack_func = PackFunc(fn, f"{id(fn)}_{grad_flag_expr}", None, True)
227
+ setattr(fn, pack_func_name, pack_func)
228
+ return pack_func(*args, **kwargs)
229
+ obj, args = args[0], args[1:]
230
+ pack_func_name = "".join((fn.__name__, "pack", grad_flag_expr))
231
+ pack_func = getattr(obj, pack_func_name, None)
232
+ if pack_func is None:
233
+ pack_func = PackFunc(fn, f"{id(obj)}_{id(fn)}_{grad_flag_expr}", obj, True)
234
+ setattr(obj, pack_func_name, pack_func)
235
+ return pack_func(*args, **kwargs)
236
+
237
+ if "MS_DEV_DISABLE_TRACE" in os.environ and os.environ["MS_DEV_DISABLE_TRACE"] == "on":
238
+ return fn
208
239
  _trace_wrap.pack_fn = fn
240
+ _trace_wrap.is_method = None
209
241
  return _trace_wrap