mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.10__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (580)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +141 -88
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/cell_wrapper.py +84 -34
  268. mindspore/nn/wrap/grad_reducer.py +8 -5
  269. mindspore/nn/wrap/loss_scale.py +105 -42
  270. mindspore/numpy/array_creations.py +1 -2
  271. mindspore/numpy/array_ops.py +3 -2
  272. mindspore/numpy/utils_const.py +5 -5
  273. mindspore/offline_debug/convert_async.py +2 -2
  274. mindspore/ops/_grad_experimental/__init__.py +0 -5
  275. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  276. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  277. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  278. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  279. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  280. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  281. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  282. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  283. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  284. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  285. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  286. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  287. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  288. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  289. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  290. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  291. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  292. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  293. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  294. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  295. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  296. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  297. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  298. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  299. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  300. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  301. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  302. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  303. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  304. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  305. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  306. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  307. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  309. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  310. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  311. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  312. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  313. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  314. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  315. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  316. mindspore/ops/_primitive_cache.py +1 -1
  317. mindspore/ops/_tracefunc.py +45 -13
  318. mindspore/ops/_utils/utils.py +6 -1
  319. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  320. mindspore/ops/_vmap/vmap_base.py +3 -3
  321. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  322. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  323. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  324. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  325. mindspore/ops/arg_dtype_cast.py +54 -0
  326. mindspore/ops/composite/base.py +37 -10
  327. mindspore/ops/composite/math_ops.py +5 -4
  328. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  329. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  330. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  331. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  333. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  334. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  335. mindspore/ops/deprecated.py +304 -0
  336. mindspore/ops/function/__init__.py +4 -1
  337. mindspore/ops/function/array_func.py +174 -193
  338. mindspore/ops/function/clip_func.py +81 -13
  339. mindspore/ops/function/debug_func.py +1 -1
  340. mindspore/ops/function/grad/grad_func.py +18 -9
  341. mindspore/ops/function/image_func.py +10 -4
  342. mindspore/ops/function/linalg_func.py +5 -5
  343. mindspore/ops/function/math_func.py +575 -386
  344. mindspore/ops/function/nn_func.py +568 -260
  345. mindspore/ops/function/random_func.py +88 -57
  346. mindspore/ops/function/sparse_func.py +1 -1
  347. mindspore/ops/function/sparse_unary_func.py +14 -12
  348. mindspore/ops/function/vmap_func.py +6 -5
  349. mindspore/ops/functional.py +15 -10
  350. mindspore/ops/op_info_register.py +244 -25
  351. mindspore/ops/operations/__init__.py +28 -19
  352. mindspore/ops/operations/_grad_ops.py +72 -7
  353. mindspore/ops/operations/_inner_ops.py +350 -17
  354. mindspore/ops/operations/_quant_ops.py +4 -8
  355. mindspore/ops/operations/_sequence_ops.py +42 -0
  356. mindspore/ops/operations/array_ops.py +68 -282
  357. mindspore/ops/operations/comm_ops.py +107 -59
  358. mindspore/ops/operations/custom_ops.py +94 -70
  359. mindspore/ops/operations/debug_ops.py +8 -4
  360. mindspore/ops/operations/image_ops.py +18 -12
  361. mindspore/ops/operations/inner_ops.py +26 -3
  362. mindspore/ops/operations/math_ops.py +189 -141
  363. mindspore/ops/operations/nn_ops.py +794 -489
  364. mindspore/ops/operations/other_ops.py +0 -22
  365. mindspore/ops/operations/random_ops.py +53 -111
  366. mindspore/ops/operations/sparse_ops.py +3 -1
  367. mindspore/ops/primitive.py +24 -18
  368. mindspore/parallel/_auto_parallel_context.py +68 -8
  369. mindspore/parallel/_cost_model_context.py +2 -2
  370. mindspore/parallel/_offload_context.py +17 -3
  371. mindspore/parallel/_parallel_serialization.py +12 -5
  372. mindspore/parallel/_ps_context.py +12 -0
  373. mindspore/parallel/_tensor.py +18 -13
  374. mindspore/parallel/_transformer/layers.py +5 -3
  375. mindspore/parallel/_transformer/loss.py +1 -0
  376. mindspore/parallel/_transformer/moe.py +2 -2
  377. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  378. mindspore/parallel/_transformer/transformer.py +23 -3
  379. mindspore/parallel/_utils.py +11 -7
  380. mindspore/parallel/algo_parameter_config.py +85 -5
  381. mindspore/parallel/checkpoint_transform.py +19 -12
  382. mindspore/parallel/shard.py +21 -14
  383. mindspore/profiler/common/struct_type.py +3 -3
  384. mindspore/profiler/common/util.py +4 -2
  385. mindspore/profiler/envprofiling.py +1 -1
  386. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  387. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  388. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  389. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  390. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  391. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  392. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  393. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  394. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  395. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  396. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  397. mindspore/profiler/parser/flops_parser.py +15 -11
  398. mindspore/profiler/parser/framework_parser.py +38 -22
  399. mindspore/profiler/parser/hccl_parser.py +16 -12
  400. mindspore/profiler/parser/integrator.py +22 -11
  401. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  402. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  403. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  404. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  405. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  406. mindspore/profiler/parser/optime_parser.py +1 -1
  407. mindspore/profiler/parser/profiler_info.py +21 -2
  408. mindspore/profiler/parser/step_trace_parser.py +11 -14
  409. mindspore/profiler/profiling.py +179 -89
  410. mindspore/rewrite/api/node.py +102 -19
  411. mindspore/rewrite/api/node_type.py +5 -1
  412. mindspore/rewrite/api/pattern_engine.py +1 -1
  413. mindspore/rewrite/api/scoped_value.py +9 -17
  414. mindspore/rewrite/api/symbol_tree.py +131 -47
  415. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  416. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  417. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  418. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  419. mindspore/rewrite/common/rewrite_elog.py +5 -1
  420. mindspore/rewrite/namer.py +33 -24
  421. mindspore/rewrite/namespace.py +14 -5
  422. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  423. mindspore/rewrite/node/call_function.py +79 -0
  424. mindspore/rewrite/node/cell_container.py +135 -0
  425. mindspore/rewrite/node/control_flow.py +88 -0
  426. mindspore/rewrite/{node.py → node/node.py} +273 -234
  427. mindspore/rewrite/node/node_manager.py +254 -0
  428. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  429. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  430. mindspore/rewrite/parsers/assign_parser.py +216 -221
  431. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  432. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  433. mindspore/rewrite/parsers/constant_parser.py +9 -6
  434. mindspore/rewrite/parsers/container_parser.py +9 -7
  435. mindspore/rewrite/parsers/for_parser.py +36 -15
  436. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  437. mindspore/rewrite/parsers/if_parser.py +28 -24
  438. mindspore/rewrite/parsers/module_parser.py +196 -25
  439. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  440. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  441. mindspore/rewrite/parsers/return_parser.py +6 -6
  442. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  443. mindspore/rewrite/sparsify/utils.py +1 -1
  444. mindspore/rewrite/symbol_tree.py +523 -578
  445. mindspore/rewrite/symbol_tree_builder.py +9 -193
  446. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  447. mindspore/run_check/_check_version.py +6 -4
  448. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  449. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  450. mindspore/scipy/linalg.py +1 -1
  451. mindspore/scipy/optimize/minimize.py +7 -3
  452. mindspore/train/_utils.py +7 -3
  453. mindspore/train/amp.py +323 -123
  454. mindspore/train/anf_ir_pb2.py +14 -2
  455. mindspore/train/callback/_backup_and_restore.py +2 -12
  456. mindspore/train/callback/_callback.py +29 -4
  457. mindspore/train/callback/_checkpoint.py +23 -8
  458. mindspore/train/callback/_early_stop.py +2 -2
  459. mindspore/train/callback/_landscape.py +4 -4
  460. mindspore/train/callback/_loss_monitor.py +2 -2
  461. mindspore/train/callback/_on_request_exit.py +2 -2
  462. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  463. mindspore/train/callback/_summary_collector.py +15 -8
  464. mindspore/train/callback/_time_monitor.py +58 -5
  465. mindspore/train/data_sink.py +5 -11
  466. mindspore/train/dataset_helper.py +84 -57
  467. mindspore/train/loss_scale_manager.py +2 -2
  468. mindspore/train/metrics/__init__.py +3 -3
  469. mindspore/train/metrics/cosine_similarity.py +1 -1
  470. mindspore/train/metrics/hausdorff_distance.py +3 -2
  471. mindspore/train/metrics/mean_surface_distance.py +3 -2
  472. mindspore/train/metrics/metric.py +39 -19
  473. mindspore/train/metrics/roc.py +2 -2
  474. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  475. mindspore/train/mind_ir_pb2.py +85 -36
  476. mindspore/train/model.py +187 -47
  477. mindspore/train/serialization.py +487 -161
  478. mindspore/train/summary/_summary_adapter.py +1 -1
  479. mindspore/train/summary/_writer_pool.py +3 -2
  480. mindspore/train/summary/summary_record.py +37 -17
  481. mindspore/train/train_thor/convert_utils.py +3 -3
  482. mindspore/train/train_thor/dataset_helper.py +1 -1
  483. mindspore/version.py +1 -1
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
  486. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  487. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  489. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  490. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  491. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  492. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  493. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  494. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  495. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  497. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  498. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  499. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  500. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  501. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  502. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  503. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  504. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  505. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  506. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  507. mindspore/_extends/graph_kernel/expander.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  509. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  510. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  511. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  512. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  513. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  514. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  515. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  516. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  517. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  518. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  519. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  520. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  521. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  522. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  523. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  524. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  525. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  526. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  527. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  528. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  529. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  530. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  531. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  532. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  534. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  535. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  536. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  537. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  538. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  539. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  540. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  543. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  544. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  545. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  546. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  547. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  548. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  549. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  550. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  551. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  552. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  553. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  554. mindspore/dataset/datapreprocess/__init__.py +0 -20
  555. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  556. mindspore/include/api/net.h +0 -142
  557. mindspore/nn/lr_scheduler.py +0 -262
  558. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  559. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  560. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  561. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  562. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  563. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  567. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  578. mindspore/rewrite/node_visitor.py +0 -44
  579. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  580. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,7 @@
15
15
  """The impl of flash attention"""
16
16
  from __future__ import absolute_import
17
17
  import mindspore.ops as ops
18
- from mindspore import dtype as mstype
18
+ import mindspore.common.dtype as mstype
19
19
  from mindspore.ops import Custom
20
20
  from mindspore.ops import DataType
21
21
  from mindspore.ops import TBERegOp
@@ -39,31 +39,28 @@ cus_flash_atten_op_info = TBERegOp("FlashAttentionPrimitive") \
39
39
  .input(0, "query", False, "required", "all") \
40
40
  .input(1, "key", False, "required", "all") \
41
41
  .input(2, "value", False, "required", "all") \
42
- .input(3, "dim_mask", False, "required", "all") \
43
- .input(4, "attn_mask", False, "optional", "all") \
44
- .input(5, "dropout_mask", False, "optional", "all") \
45
- .input(6, "alibi_mask", False, "optional", "all") \
42
+ .input(3, "attn_mask", False, "optional", "all") \
43
+ .input(4, "dropout_mask", False, "optional", "all") \
44
+ .input(5, "alibi_mask", False, "optional", "all") \
46
45
  .output(0, "output", False, "required", "all") \
47
46
  .output(1, "rowsum", False, "required", "all") \
48
47
  .output(2, "rowmax", False, "required", "all") \
49
- .dtype_format(DataType.F16_Default,
50
- DataType.F16_Default,
51
- DataType.F16_Default,
52
- DataType.I8_Default,
53
- DataType.F16_Default,
54
- DataType.F16_Default,
55
- DataType.F16_Default,
48
+ .dtype_format(DataType.F16_FracNZ,
49
+ DataType.F16_FracNZ,
50
+ DataType.F16_FracNZ,
51
+ DataType.F16_FracNZ,
56
52
  DataType.F16_Default,
53
+ DataType.F16_FracNZ,
54
+ DataType.F16_FracNZ,
57
55
  DataType.F16_Default,
58
56
  DataType.F16_Default) \
59
- .dtype_format(DataType.F16_Default,
60
- DataType.F16_Default,
61
- DataType.F16_Default,
62
- DataType.I8_Default,
63
- DataType.F16_Default,
64
- DataType.F16_Default,
65
- DataType.F16_Default,
57
+ .dtype_format(DataType.F16_FracNZ,
58
+ DataType.F16_FracNZ,
59
+ DataType.F16_FracNZ,
60
+ DataType.F16_FracNZ,
66
61
  DataType.F16_Default,
62
+ DataType.F16_FracNZ,
63
+ DataType.F16_FracNZ,
67
64
  DataType.F32_Default,
68
65
  DataType.F16_Default) \
69
66
  .get_op_info()
@@ -88,41 +85,38 @@ cus_flash_atten_grad_op_info = TBERegOp("FlashAttentionGradPrimitive") \
88
85
  .input(4, "do", False, "required", "all") \
89
86
  .input(5, "rowsum", False, "required", "all") \
90
87
  .input(6, "rowmax", False, "required", "all") \
91
- .input(7, "dim_mask", False, "required", "all") \
92
- .input(8, "attn_mask", False, "optional", "all") \
93
- .input(9, "dropout_mask", False, "optional", "all") \
94
- .input(10, "alibi_mask", False, "optional", "all") \
88
+ .input(7, "attn_mask", False, "optional", "all") \
89
+ .input(8, "dropout_mask", False, "optional", "all") \
90
+ .input(9, "alibi_mask", False, "optional", "all") \
95
91
  .output(0, "dq", False, "required", "all") \
96
92
  .output(1, "dk", False, "required", "all") \
97
93
  .output(2, "dv", False, "required", "all") \
98
- .dtype_format(DataType.F16_Default,
99
- DataType.F16_Default,
100
- DataType.F16_Default,
101
- DataType.F16_Default,
102
- DataType.F16_Default,
103
- DataType.F16_Default,
104
- DataType.F16_Default,
105
- DataType.I8_Default,
106
- DataType.F16_Default,
107
- DataType.F16_Default,
108
- DataType.F16_Default,
109
- DataType.F32_Default,
110
- DataType.F32_Default,
111
- DataType.F32_Default) \
112
- .dtype_format(DataType.F16_Default,
113
- DataType.F16_Default,
114
- DataType.F16_Default,
115
- DataType.F16_Default,
116
- DataType.F16_Default,
94
+ .dtype_format(DataType.F16_FracNZ,
95
+ DataType.F16_FracNZ,
96
+ DataType.F16_FracNZ,
97
+ DataType.F16_FracNZ,
98
+ DataType.F16_FracNZ,
99
+ DataType.F16_Default,
100
+ DataType.F16_Default,
101
+ DataType.F16_FracNZ,
102
+ DataType.F16_Default,
103
+ DataType.F16_FracNZ,
104
+ DataType.F32_FracNZ,
105
+ DataType.F32_FracNZ,
106
+ DataType.F32_FracNZ) \
107
+ .dtype_format(DataType.F16_FracNZ,
108
+ DataType.F16_FracNZ,
109
+ DataType.F16_FracNZ,
110
+ DataType.F16_FracNZ,
111
+ DataType.F16_FracNZ,
117
112
  DataType.F32_Default,
118
113
  DataType.F16_Default,
119
- DataType.I8_Default,
120
- DataType.F16_Default,
114
+ DataType.F16_FracNZ,
121
115
  DataType.F16_Default,
122
- DataType.F16_Default,
123
- DataType.F32_Default,
124
- DataType.F32_Default,
125
- DataType.F32_Default) \
116
+ DataType.F16_FracNZ,
117
+ DataType.F32_FracNZ,
118
+ DataType.F32_FracNZ,
119
+ DataType.F32_FracNZ) \
126
120
  .get_op_info()
127
121
 
128
122
 
@@ -131,11 +125,11 @@ def get_flash_attention_grad(prev_block_num=65536, next_block_num=65536,
131
125
  """get flash attention grad"""
132
126
 
133
127
  def infer_shape(q_shape, k_shape, v_shape, o_shape, do_shape, l_shape, m_shape,
134
- dim_mask_shape, att_mask_shape, dropout_mask_shape, alibi_mask_shape):
128
+ att_mask_shape, dropout_mask_shape, alibi_mask_shape):
135
129
  return q_shape, k_shape, v_shape
136
130
 
137
131
  def infer_dtype(q_dtype, k_dtype, v_dtype, o_dytpe, do_dtype, l_dtype, m_dtype,
138
- dim_mask_dtype, attn_mask_dtype, dropout_mask_dtype, alibi_mask_type):
132
+ attn_mask_dtype, dropout_mask_dtype, alibi_mask_type):
139
133
  return mstype.float32, mstype.float32, mstype.float32
140
134
 
141
135
  fa_grad = Custom(flash_attention_grad, out_shape=infer_shape,
@@ -145,20 +139,20 @@ def get_flash_attention_grad(prev_block_num=65536, next_block_num=65536,
145
139
  fa_grad.add_prim_attr("high_precision", high_precision)
146
140
  fa_grad.add_prim_attr("tiling_stgy_name", tiling_stgy_name)
147
141
  fa_grad.init_prim_io_names(
148
- inputs=["query", "key", "value", "output", "do", "rowsum", "rowmax", "dim_mask", "attn_mask", "dropout_mask",
142
+ inputs=["query", "key", "value", "output", "do", "rowsum", "rowmax", "attn_mask", "dropout_mask",
149
143
  "alibi_mask"],
150
144
  outputs=["dq", "dk", "dv"]
151
145
  )
152
146
 
153
- def bprop(query, key, value, dim_mask, attn_mask, dropout_mask, alibi_mask, out, douts):
147
+ def bprop(query, key, value, attn_mask, dropout_mask, alibi_mask, out, douts):
154
148
  output, rowsum, rowmax = out
155
149
  dout, _, _ = douts
156
- dq, dk, dv = fa_grad(query, key, value, output, dout, rowsum, rowmax, dim_mask, attn_mask, dropout_mask,
150
+ dq, dk, dv = fa_grad(query, key, value, output, dout, rowsum, rowmax, attn_mask, dropout_mask,
157
151
  alibi_mask)
158
152
  dq = ops.cast(dq, mstype.float16)
159
153
  dk = ops.cast(dk, mstype.float16)
160
154
  dv = ops.cast(dv, mstype.float16)
161
- return dq, dk, dv, zeros_like(dim_mask), zeros_like(attn_mask), \
155
+ return dq, dk, dv, zeros_like(attn_mask), \
162
156
  zeros_like(dropout_mask), zeros_like(alibi_mask)
163
157
 
164
158
  return bprop
@@ -167,7 +161,7 @@ def get_flash_attention_grad(prev_block_num=65536, next_block_num=65536,
167
161
  def get_flash_attention(prev_block_num=65536, next_block_num=65536, tiling_stgy_name='sparse', high_precision=False):
168
162
  """get_flash_attention"""
169
163
 
170
- def infer_shape(q_shape, k_shape, v_shape, dim_mask_shape, attn_mask_shape=None,
164
+ def infer_shape(q_shape, k_shape, v_shape, attn_mask_shape=None,
171
165
  dropout_mask_shape=None, alibi_mask_shape=None):
172
166
  """infer shape"""
173
167
  batch, hidden_size, seq_len, _ = q_shape
@@ -175,7 +169,7 @@ def get_flash_attention(prev_block_num=65536, next_block_num=65536, tiling_stgy_
175
169
  m_shape = (batch, hidden_size, seq_len)
176
170
  return q_shape, l_shape, m_shape
177
171
 
178
- def infer_dtype(q_dtype, k_dtype, v_dtype, dim_mask_dtype, attn_mask_dtype=None,
172
+ def infer_dtype(q_dtype, k_dtype, v_dtype, attn_mask_dtype=None,
179
173
  dropout_mask_dtype=None, alibi_mask_type=None):
180
174
  """infer type"""
181
175
  l_dtype = mstype.float16
@@ -192,7 +186,7 @@ def get_flash_attention(prev_block_num=65536, next_block_num=65536, tiling_stgy_
192
186
  fa_forward.add_prim_attr("high_precision", high_precision)
193
187
  fa_forward.add_prim_attr("tiling_stgy_name", tiling_stgy_name)
194
188
  fa_forward.init_prim_io_names(
195
- inputs=["query", "key", "value", "dim_mask", "attn_mask", "dropout_mask", "alibi_mask"],
189
+ inputs=["query", "key", "value", "attn_mask", "dropout_mask", "alibi_mask"],
196
190
  outputs=["output", "rowsum", "rowmax"]
197
191
  )
198
192
 
@@ -19,7 +19,6 @@ from mindspore.ops._op_impl._custom_op.flash_attention.constants import DTYPE_SI
19
19
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
20
20
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
21
21
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import L0C
22
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import L1
23
22
  from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
24
23
 
25
24
 
@@ -179,7 +178,7 @@ class TikOpsUtils:
179
178
  def broadcast(self, vec_ub, shape):
180
179
  """ broadcast a vector to a matrix
181
180
  :param vec_ub: a tensor in UB with shape of (M,), and dtype is float16
182
- :param shape: the target shape, a tuple with value (M, N)M and N are integer multiples of 16
181
+ :param shape: the target shape, a tuple with value (M, N), M and N are integer multiples of 16
183
182
  :return: a tensor in UB with shape of (M, N)
184
183
  """
185
184
  M, N = shape
@@ -321,27 +320,16 @@ class TikOpsUtils:
321
320
  )
322
321
  return vec_rec_ub
323
322
 
324
- def row_sum_cube_impl(self, matrix_l1_K1MK0_ed, rowsum_ub, m, k, precision_type):
323
+ def row_sum_cube_impl(self, matrix_l1_K1MK0_ed, right_all_one_matrix_l1, rowsum_ub, m, k, precision_type):
325
324
  """用cube实现矩阵行和:右乘一个shape=(n,1)全一矩阵
326
325
  :param matrix_l1_K1MK0_ed: input tensor with shape (K1, M, K0)
327
- :param rowsum_ub: output tensor stores the row sum of input tensor.
326
+ :param right_all_one_matrix_l1: input tensor with shape (K, 16)
327
+ :param rowsum_ub: output tensor stores the row sum of input tensor
328
328
  :param m: actual tensor height
329
329
  :param k: actual tensor width
330
330
  :return: row sum of the output tensor
331
331
  """
332
332
  K1, M, K0 = matrix_l1_K1MK0_ed.shape
333
- K = K1 * K0
334
-
335
- # 构造全一右矩阵,由于cube无法处理shape=(n, 1),所以shape=(n, 16),全一矩阵不需分形
336
- right_all_one_matrix_ub = self.tik_instance.Tensor(
337
- FP16, (K, 16), name="right_all_one_matrix_ub", scope=UB
338
- )
339
- self.tik_instance.h_duplicate(right_all_one_matrix_ub, 1.0)
340
- right_all_one_matrix_l1 = self.tik_instance.Tensor(
341
- FP16, (K1 * K0, 16), name="right_all_one_matrix_l1", scope=L1
342
- )
343
- self.cont_data_mv_1_bust(dst=right_all_one_matrix_l1, src=right_all_one_matrix_ub, burst=K)
344
-
345
333
  # 调用matmul实现rowsum,结果shape=(m, 16),取每行的第一个数
346
334
  with self.tik_instance.new_stmt_scope(disable_sync=False):
347
335
  row_sum_ub_N1MN0 = self.matmul_compute(matrix_l1_K1MK0_ed, right_all_one_matrix_l1, m, k, 16,
@@ -352,6 +340,7 @@ class TikOpsUtils:
352
340
  cur_row_sum = self.tik_instance.Scalar(FP32, init_value=row_sum_ub_MN_ed[idx, 0])
353
341
  rowsum_ub[idx].set_as(cur_row_sum)
354
342
  else:
343
+ # row_sum_ub_MN_ed 先转置,然后取一行, 替换原来按行操作: lij_ub[i].set_as(row_sum_ub_MN_ed[i, 0])
355
344
  row_sum_ub_trans = self.tik_instance.Tensor(FP16, (16, M), name="row_sum_ub_trans", scope=UB)
356
345
  row_sum_ub_trans = self.transpose_matrix(row_sum_ub_MN_ed, row_sum_ub_trans, M, True)
357
346
  self.cont_data_mv_1_bust(dst=rowsum_ub, src=row_sum_ub_trans, burst=M // 16)
@@ -409,7 +398,7 @@ class TikOpsUtils:
409
398
  offset = vec_len - a_burst_num
410
399
  last_blk_ub = self.tik_instance.Tensor(FP16, (a_burst_num,), name="last_blk_ub", scope=UB)
411
400
  self.cont_data_mv_1_bust(dst=last_blk_ub, src=src_tensor[gm_offset + offset], burst=1)
412
- with self.tik_instance.for_range(0, a_burst_num) as idx: # offset非32bytes对齐,无法用datamove
401
+ with self.tik_instance.for_range(0, a_burst_num) as idx: # offset非32bytes对齐, 无法用datamove
413
402
  dst_tensor[offset + idx].set_as(last_blk_ub[idx])
414
403
 
415
404
  def move_vector_from_ub_to_gm(self, dst_tensor, src_tensor, gm_offset, block_h):
@@ -29,7 +29,7 @@ class WukongTiling(TilingStrategy):
29
29
  反向的空间分布待详细分析
30
30
  N = (4096, 1024, 256, 64) 或 77
31
31
  Nq = (4096, 1024, 256, 64)
32
- d = dv = (40, 80, 160160)
32
+ d = dv = (40, 80, 160, 160)
33
33
  """
34
34
  if self.N <= 77: # [77, 64]
35
35
  # cross-attention or self-attention of (64, 64, 160)
@@ -108,6 +108,7 @@ from .search_sorted import _search_sorted_aicpu
108
108
  from .stack import _stack_aicpu
109
109
  from .unstack import _unstack_aicpu
110
110
  from .unsorted_segment_sum import _unsorted_segment_sum_aicpu
111
+ from .unsorted_segment_prod import _unsorted_segment_prod_aicpu
111
112
  from .addcmul import _addcmul_aicpu
112
113
  from .uniform_candidate_sampler import _uniform_candidate_sampler_aicpu
113
114
  from .log_uniform_candidate_sampler import _log_uniform_candidate_sampler_aicpu
@@ -145,6 +146,7 @@ from .upsample_trilinear_3d import _upsample_trilinear_3d_aicpu
145
146
  from .upsample_trilinear_3d_grad import _upsample_trilinear_3d_grad_aicpu
146
147
  from .upper_bound import _upper_bound_aicpu
147
148
  from .cache_swap_table import _cache_swap_table_aicpu
149
+ from .uniform import _uniform_aicpu
148
150
  from .uniform_int import _uniform_int_aicpu
149
151
  from .uniform_real import _uniform_real_aicpu
150
152
  from .standard_laplace import _standard_laplace_aicpu
@@ -156,12 +158,13 @@ from .fused_sparse_adam import _fused_sparse_adam_aicpu
156
158
  from .fused_sparse_lazy_adam import _fused_sparse_lazy_adam_aicpu
157
159
  from .fused_sparse_ftrl import _fused_sparse_ftrl_aicpu
158
160
  from .sparse_fill_empty_rows_grad import _sparse_fill_empty_rows_grad_aicpu
161
+ from .sparse_reorder import _sparse_reorder_aicpu
159
162
  from .sparse_reshape import _sparse_reshape_aicpu
160
163
  from .sparse_segment_sqrt_n_grad import _sparse_segment_sqrt_n_grad_aicpu
161
164
  from .sparse_segment_sum import _sparse_segment_sum_aicpu
162
165
  from .sparse_segment_sum_with_num_segments import _sparse_segment_sum_with_num_segments_aicpu
163
166
  from .sparse_softmax_cross_entropy_with_logits_v2 import _sparse_softmax_cross_entropy_with_logits_v2_aicpu
164
- from .sparsesparsemaximum import _sparsesparsemaximum_aicpu
167
+ from .sparse_sparse_maximum import _sparse_sparse_maximum_aicpu
165
168
  from .split import _split_aicpu
166
169
  from .transpose import _transpose_aicpu
167
170
  from .tril_indices import _tril_indices_aicpu
@@ -205,6 +208,7 @@ from .environ_get import _environ_get_aicpu
205
208
  from .environ_destroy_all import _environ_destroy_all_aicpu
206
209
  from .cross import _cross_aicpu
207
210
  from .check_numerics import _check_numerics_aicpu
211
+ from .cummax import _cummax_aicpu
208
212
  from .cumsum import _cumsum_aicpu
209
213
  from .round import _round_aicpu
210
214
  from .stft import _stft_aicpu
@@ -229,6 +233,7 @@ from .scatter_nd_update import _scatter_nd_update_aicpu
229
233
  from .scatter_nd_max import _scatter_nd_max_aicpu
230
234
  from .conj import _conj_aicpu
231
235
  from .scatter_nd_min import _scatter_nd_min_aicpu
236
+ from .scatter_add_with_axis import _scatter_add_with_axis_aicpu
232
237
  from .compare_and_bitpack import _compare_and_bitpack_aicpu
233
238
  from .addcdiv import _addcdiv_aicpu
234
239
  from .unique_consecutive import _unique_consecutive_aicpu
@@ -241,8 +246,8 @@ from .reservoir_replay_buffer import _rrb_push_op_cpu
241
246
  from .reservoir_replay_buffer import _rrb_sample_op_cpu
242
247
  from .reservoir_replay_buffer import _rrb_destroy_op_cpu
243
248
  from .concat_offset import _concat_offset_aicpu
244
- from .concat_offset_v1 import _concat_offset_v1_aicpu
245
249
  from .range import _range_aicpu
250
+ from .range_v2 import _range_v2_aicpu
246
251
  from .slice_grad import _slice_grad_aicpu
247
252
  from .median import _median_aicpu
248
253
  from .median_grad import _median_grad_aicpu
@@ -272,6 +277,7 @@ from .complex import _complex_aicpu
272
277
  from .complex_abs import _complex_abs_aicpu
273
278
  from .concat import _concat_aicpu
274
279
  from .cos import _cos_aicpu
280
+ from .count_nonzero import _count_nonzero_aicpu
275
281
  from .csr_sparse_matrix_to_dense import _csr_sparse_matrix_to_dense_aicpu
276
282
  from .cumprod import _cumprod_aicpu
277
283
  from .exp import _exp_aicpu
@@ -340,6 +346,7 @@ from .hypot import _hypot_aicpu
340
346
  from .identity_n import _identity_n_aicpu
341
347
  from .index_fill import _index_fill_aicpu
342
348
  from .index_put import _index_put_aicpu
349
+ from .inplace_index_add import _inplace_index_add_aicpu
343
350
  from .kldivloss import _kldiv_loss_aicpu
344
351
  from .kldivlossgrad import _kldiv_loss_grad_aicpu
345
352
  from .lcm import _lcm_aicpu
@@ -400,6 +407,9 @@ from .non_deterministic_ints import _non_deterministic_ints_aicpu
400
407
  from .pow import _pow_aicpu
401
408
  from .real import _real_aicpu
402
409
  from .resize_area import _resize_area_aicpu
410
+ from .segment_mean import _segment_mean_aicpu
411
+ from .segment_min import _segment_min_aicpu
412
+ from .segment_prod import _segment_prod_aicpu
403
413
  from .segment_sum import _segment_sum_aicpu
404
414
  from .set_size import _set_size_aicpu
405
415
  from .slice import _slice_aicpu
@@ -411,6 +421,7 @@ from .sparse_tensor_dense_mat_mul import _sparse_tensor_dense_mat_mul_aicpu
411
421
  from .trace import _trace_aicpu
412
422
  from .tracegrad import _tracegrad_aicpu
413
423
  from .tridiagonal_solve import _tridiagonal_solve_aicpu
424
+ from .tridiagonal_matmul import _tridiagonal_matmul_aicpu
414
425
  from .truncated_normal import _truncated_normal_aicpu
415
426
  from .glu import _glu_aicpu
416
427
  from .deformable_offsets import _deformable_offsets_aicpu
@@ -426,3 +437,4 @@ from .sequence_concat import _sequence_concat_aicpu
426
437
  from .sequence_stack import _sequence_stack_aicpu
427
438
  from .affine_grid import _affine_grid_aicpu
428
439
  from .depth_to_space import _depth_to_space_aicpu
440
+ from .eps import _eps_aicpu
@@ -29,9 +29,9 @@ add_op_info = AiCPURegOp("Add") \
29
29
  .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
30
30
  .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
31
31
  .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
32
- .dtype_format(DataType.U16_Default, DataType.I16_Default, DataType.I16_Default) \
33
- .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default) \
34
- .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.I64_Default) \
32
+ .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
33
+ .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
34
+ .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
35
35
  .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
36
36
  .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
37
37
  .get_op_info()
@@ -31,7 +31,6 @@ bias_add_grad_op_info = AiCPURegOp("BiasAddGrad") \
31
31
  .dtype_format(DataType.I64_Default, DataType.I64_Default) \
32
32
  .dtype_format(DataType.F16_Default, DataType.F16_Default) \
33
33
  .dtype_format(DataType.F32_Default, DataType.F32_Default) \
34
- .dtype_format(DataType.F64_Default, DataType.F64_Default) \
35
34
  .dtype_format(DataType.C64_Default, DataType.C64_Default) \
36
35
  .dtype_format(DataType.C128_Default, DataType.C128_Default) \
37
36
  .get_op_info()
@@ -0,0 +1,43 @@
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """CountNonZero op"""
17
+ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
18
+
19
+ count_nonzero_op_info = AiCPURegOp("CountNonZero") \
20
+ .fusion_type("OPAQUE") \
21
+ .input(0, "x", "required") \
22
+ .output(0, "y", "required") \
23
+ .attr("dims", "listInt")\
24
+ .dtype_format(DataType.I8_Default, DataType.I64_Default) \
25
+ .dtype_format(DataType.I16_Default, DataType.I64_Default) \
26
+ .dtype_format(DataType.I32_Default, DataType.I64_Default) \
27
+ .dtype_format(DataType.I64_Default, DataType.I64_Default) \
28
+ .dtype_format(DataType.U8_Default, DataType.I64_Default) \
29
+ .dtype_format(DataType.U16_Default, DataType.I64_Default) \
30
+ .dtype_format(DataType.U32_Default, DataType.I64_Default) \
31
+ .dtype_format(DataType.U64_Default, DataType.I64_Default) \
32
+ .dtype_format(DataType.F16_Default, DataType.I64_Default) \
33
+ .dtype_format(DataType.F32_Default, DataType.I64_Default) \
34
+ .dtype_format(DataType.F64_Default, DataType.I64_Default) \
35
+ .dtype_format(DataType.C64_Default, DataType.I64_Default) \
36
+ .dtype_format(DataType.C128_Default, DataType.I64_Default) \
37
+ .get_op_info()
38
+
39
+
40
+ @op_info_register(count_nonzero_op_info)
41
+ def _count_nonzero_aicpu():
42
+ """CountNonZero AiCPU register"""
43
+ return
@@ -0,0 +1,32 @@
1
+ # Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """Eps op"""
17
+ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
18
+
19
+ eps_op_info = AiCPURegOp("Eps") \
20
+ .fusion_type("OPAQUE") \
21
+ .input(0, "x", "required") \
22
+ .output(0, "y", "required") \
23
+ .dtype_format(DataType.F16_Default, DataType.F16_Default) \
24
+ .dtype_format(DataType.F32_Default, DataType.F32_Default) \
25
+ .dtype_format(DataType.F64_Default, DataType.F64_Default) \
26
+ .get_op_info()
27
+
28
+
29
+ @op_info_register(eps_op_info)
30
+ def _eps_aicpu():
31
+ """Eps AiCPU register"""
32
+ return
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
 
16
- """RandomGamma op"""
16
+ """Gamma op"""
17
17
  from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
18
18
 
19
19
  gamma_op_info = AiCPURegOp("Gamma") \
@@ -32,5 +32,5 @@ gamma_op_info = AiCPURegOp("Gamma") \
32
32
 
33
33
  @op_info_register(gamma_op_info)
34
34
  def _gamma_aicpu():
35
- """RandomGamma AiCPU register"""
35
+ """Gamma AiCPU register"""
36
36
  return
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -18,15 +18,18 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
18
18
  log_uniform_candidate_sampler_op_info = AiCPURegOp("LogUniformCandidateSampler") \
19
19
  .fusion_type("OPAQUE") \
20
20
  .input(0, "true_classes", "required") \
21
+ .input(1, "counts", "required") \
22
+ .input(2, "states", "required") \
21
23
  .output(0, "sampled_candidates", "required") \
22
24
  .output(1, "true_expected_count", "required") \
23
- .output(2, "true_expected_count", "required") \
25
+ .output(2, "sampled_expected_count", "required") \
24
26
  .attr("num_true", "int") \
25
27
  .attr("num_sampled", "int") \
26
28
  .attr("unique", "bool") \
27
29
  .attr("range_max", "int") \
28
30
  .attr("seed", "int") \
29
- .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
31
+ .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default, DataType.I64_Default,
32
+ DataType.F32_Default, DataType.F32_Default) \
30
33
  .get_op_info()
31
34
 
32
35
 
@@ -19,7 +19,6 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
19
19
  lu_unpack_grad_op_info = AiCPURegOp("LuUnpackGrad") \
20
20
  .fusion_type("OPAQUE") \
21
21
  .attr("L_grad_flag", "bool") \
22
- .attr("L_grad_flag", "bool") \
23
22
  .input(0, "L_grad", "required") \
24
23
  .input(1, "U_grad", "required") \
25
24
  .input(2, "LU_data", "required") \
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
1
+ # Copyright 2022-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -20,8 +20,8 @@ multinomial_op_info = AiCPURegOp("Multinomial") \
20
20
  .fusion_type("OPAQUE") \
21
21
  .input(0, "input", "required") \
22
22
  .input(1, "num_sample", "required") \
23
- .input(2, "count", "required") \
24
- .input(3, "state", "required") \
23
+ .input(2, "counts", "required") \
24
+ .input(3, "states", "required") \
25
25
  .output(0, "output", "required") \
26
26
  .attr("dtype", "Type") \
27
27
  .attr("seed", "int") \
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
1
+ # Copyright 2022-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -22,21 +22,29 @@ parameterized_truncated_normal_op_info = AiCPURegOp("ParameterizedTruncatedNorma
22
22
  .input(2, "stdevs", "required") \
23
23
  .input(3, "min", "required") \
24
24
  .input(4, "max", "required") \
25
+ .input(5, "counts", "required") \
26
+ .input(6, "states", "required") \
25
27
  .output(0, "y", "required") \
26
28
  .attr("seed", "int")\
27
29
  .attr("seed2", "int")\
28
30
  .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default,
29
- DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
31
+ DataType.F16_Default, DataType.F16_Default, DataType.U64_Default,
32
+ DataType.U64_Default, DataType.F16_Default) \
30
33
  .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default,
31
- DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
34
+ DataType.F32_Default, DataType.F32_Default, DataType.U64_Default,
35
+ DataType.U64_Default, DataType.F32_Default) \
32
36
  .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.F64_Default,
33
- DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
37
+ DataType.F64_Default, DataType.F64_Default, DataType.U64_Default,
38
+ DataType.U64_Default, DataType.F64_Default) \
34
39
  .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.F16_Default,
35
- DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
40
+ DataType.F16_Default, DataType.F16_Default, DataType.U64_Default,
41
+ DataType.U64_Default, DataType.F16_Default) \
36
42
  .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.F32_Default,
37
- DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
43
+ DataType.F32_Default, DataType.F32_Default, DataType.U64_Default,
44
+ DataType.U64_Default, DataType.F32_Default) \
38
45
  .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.F64_Default,
39
- DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
46
+ DataType.F64_Default, DataType.F64_Default, DataType.U64_Default,
47
+ DataType.U64_Default, DataType.F64_Default) \
40
48
  .get_op_info()
41
49
 
42
50
 
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -21,25 +21,45 @@ random_categorical_op_info = AiCPURegOp("RandomCategorical") \
21
21
  .input(0, "logits", "required") \
22
22
  .input(1, "num_sample", "required") \
23
23
  .input(2, "seed", "required") \
24
+ .input(3, "counts", "required") \
25
+ .input(4, "states", "required") \
24
26
  .output(0, "output", "required") \
25
- .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
26
- .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
27
- .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
28
- .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
29
- .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
30
- .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
31
- .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
32
- .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
33
- .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
34
- .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
35
- .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
36
- .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
37
- .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
38
- .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
39
- .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
40
- .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
41
- .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
42
- .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
27
+ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
28
+ DataType.U64_Default, DataType.I16_Default) \
29
+ .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
30
+ DataType.U64_Default, DataType.I16_Default) \
31
+ .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
32
+ DataType.U64_Default, DataType.I16_Default) \
33
+ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
34
+ DataType.U64_Default, DataType.I32_Default) \
35
+ .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
36
+ DataType.U64_Default, DataType.I32_Default) \
37
+ .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
38
+ DataType.U64_Default, DataType.I32_Default) \
39
+ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
40
+ DataType.U64_Default, DataType.I64_Default) \
41
+ .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
42
+ DataType.U64_Default, DataType.I64_Default) \
43
+ .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
44
+ DataType.U64_Default, DataType.I64_Default) \
45
+ .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
46
+ DataType.U64_Default, DataType.I16_Default) \
47
+ .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
48
+ DataType.U64_Default, DataType.I16_Default) \
49
+ .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
50
+ DataType.U64_Default, DataType.I16_Default) \
51
+ .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
52
+ DataType.U64_Default, DataType.I32_Default) \
53
+ .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
54
+ DataType.U64_Default, DataType.I32_Default) \
55
+ .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
56
+ DataType.U64_Default, DataType.I32_Default) \
57
+ .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
58
+ DataType.U64_Default, DataType.I64_Default) \
59
+ .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
60
+ DataType.U64_Default, DataType.I64_Default) \
61
+ .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
62
+ DataType.U64_Default, DataType.I64_Default) \
43
63
  .get_op_info()
44
64
 
45
65
  @op_info_register(random_categorical_op_info)