mindspore 2.4.0__cp311-cp311-macosx_11_0_arm64.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1387)
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2544 @@
1
+ # Copyright 2020-2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """Inner operators."""
17
+ from types import FunctionType, MethodType
18
+ from collections.abc import Iterable
19
+ import os
20
+ import weakref
21
+ import numpy as np
22
+
23
+ from mindspore.common import Tensor
24
+ from mindspore.common._stub_tensor import StubTensor
25
+ from mindspore.ops import composite as C
26
+ from mindspore.ops.operations.array_ops import Cast
27
+ from mindspore.ops.operations._scalar_ops import bit_or, bit_and
28
+ from mindspore.ops import signature as sig
29
+ from mindspore.ops.operations.math_ops import _infer_shape_reduce
30
+ from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
31
+ _run_op, _check_contains_variable
32
+ from mindspore._c_expression import Tensor as Tensor_
33
+ from mindspore._c_expression import typing, HookType
34
+ from mindspore import _checkparam as validator
35
+ from mindspore.common import dtype as mstype
36
+ from mindspore.common.parameter import Parameter
37
+ from mindspore.communication.management import GlobalComm, get_rank, _get_group, get_group_size
38
+ from mindspore.common.api import _pynative_executor
39
+ from mindspore.common._register_for_adapter import ms_adapter_registry
40
+ from mindspore import ops
41
+ from ..auto_generate import TensorCopySlices, SiLU, Cummin, TopKRouter, ExtractImagePatches, DecoderKVCache, \
42
+ PromptKVCache, ApplyCamePart1, ApplyCamePart2, ApplyCamePart3, ApplyCamePart4
43
+
44
# Bit operation
# `bit_and` and `bit_or` are imported from `_scalar_ops` at the top of the file;
# calling them here rebinds each name to the result of that call, so the rest of
# this module sees the created object instead of the imported callable.
bit_and = bit_and()
bit_or = bit_or()
# The remaining bit operators are built directly as `Primitive` objects from
# their name strings.
bit_xor = Primitive("bit_xor")
bit_left_shift = Primitive("bit_left_shift")
bit_right_shift = Primitive("bit_right_shift")
# String operation
# Each string operator is likewise a `Primitive` constructed from its name string.
string_lt = Primitive("string_lt")
string_gt = Primitive("string_gt")
string_le = Primitive("string_le")
string_ge = Primitive("string_ge")
string_not = Primitive("string_not")
string_in = Primitive("string_in")
string_mul = Primitive("string_mul")
string_getitem = Primitive("string_getitem")
59
+
60
+
61
class Generator(Primitive):
    r"""
    Primitive that manages the state of random number generation.

    Inputs:
        - **cmd** (int) - The operation to be executed.
        - **inputs** (tuple[tensor]) - The inputs for that operation.

    Outputs:
        - **seed** (Tensor) - Seed for the random number generation algorithm.
        - **offset** (Tensor) - Offset of the random number sequence.
        - **state** (Tensor) - State tensor, can be used to restore current state.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Generator and mark it as having a memory side effect."""
        self.add_prim_attr("side_effect_mem", True)

    def __call__(self, cmd, inputs):
        """
        Execute `cmd` on `inputs`.

        The step command (``cmd == 0``) is answered directly with the first two
        elements of `inputs`; any other command is forwarded to the parent
        class's ``__call__``.
        """
        is_step_cmd = cmd == 0
        if not is_step_cmd:
            return super().__call__(cmd, inputs)
        # step cmd: hand back the first two inputs without dispatching the primitive
        first, second = inputs[0], inputs[1]
        return first, second
83
+
84
+
85
class Quant(PrimitiveWithInfer):
    r"""
    Returns the quantized value of input_x.

    If `sqrt_mode` is False:

    .. math::
        y = round(scale * x + offset)

    If `sqrt_mode` is True:

    .. math::
        y = round(scale * x * scale + offset)

    Note:
        This operation only supports Atlas 200/300/500 inference products.

    Args:
        scale (float): Specifies the scaling ratio.
        offset (float): Specifies the offset.
        sqrt_mode (bool): Specifies whether to perform square root on `scale`. Default: ``False``.
        round_mode (str): Specifies the way to round. Must be one of ["Round", "Floor", "Ceil", "Trunc"].
            Default: "Round".

    Inputs:
        - **input_x** (Tensor) - Input tensor. Its data type must be mindspore.float16 or mindspore.float32.

    Outputs:
        - Tensor: The quantized output tensor of type mindspore.int8.

    Examples:
        >>> input_x = Tensor([100.0, 150.0], mstype.float32)
        >>> quant = ops.Quant(80.0, 0.0, False, "Round")
        >>> y = quant(input_x)
    """

    @prim_attr_register
    def __init__(self, scale, offset, sqrt_mode=False, round_mode="Round"):
        """Validate and store the quantization attributes; the output type attribute is int8."""
        prim_name = self.name
        supported_round_modes = ["Round", "Floor", "Ceil", "Trunc"]

        self.scale = validator.check_value_type("scale", scale, [float], prim_name)
        self.offset = validator.check_value_type("offset", offset, [float], prim_name)
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], prim_name)
        self.round_mode = validator.check_string(round_mode, supported_round_modes, "round_mode", prim_name)
        self.add_prim_attr("dst_type", mstype.int8)

    def infer_shape(self, x_shape):
        """The output has the same shape as the input."""
        return x_shape

    def infer_dtype(self, x_type):
        """Check that the input is a float16/float32 tensor and return the registered `dst_type`."""
        accepted_dtypes = [mstype.float16, mstype.float32]
        validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
        validator.check_type_name("input_x", x_type, accepted_dtypes, self.name)
        prim_attrs = self.get_attr_dict()
        return prim_attrs['dst_type']
137
+
138
+
139
class Lamb(PrimitiveWithInfer):
    r"""
    LAMB optimizer algorithm.

    The Lamb optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
    <https://arxiv.org/abs/1904.00962>`_.

    Inputs:
        - **var** (Tensor) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
          any number of additional dimensions. The data type can be float16 or float32.
        - **m** (Tensor) - The 1st moment vector in the updating formula,
          the shape and data type value should be the same as `var`.
        - **v** (Tensor) - The 2nd moment vector in the updating formula,
          the shape and data type value should be the same as `var`. Mean square gradients with the same type
          as `var`.
        - **lr** (float) - :math:`l` in the updating formula. The paper suggested value is :math:`10^{-8}`,
          the data type value should be the same as `var`.
        - **beta1** (float) - The exponential decay rate for the 1st moment estimations,
          the data type value should be the same as `var`. The paper suggested value is :math:`0.9`
        - **beta2** (float) - The exponential decay rate for the 2nd moment estimations,
          the data type value should be the same as `var`. The paper suggested value is :math:`0.999`
        - **epsilon** (float) - Term added to the denominator to improve numerical stability.
        - **decay** (float) - The weight decay value, must be a scalar tensor with float data type.
          Default: 0.0.
        - **global_step** (Tensor) - Tensor to record current global step.
        - **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.

    Outputs:
        Tensor, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Lamb."""
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
                    epsilon_shape, decay_shape, global_step_shape, gradient_shape):
        """Require `m`, `v` and `gradient` to match the shape of `var`, which is also the output shape."""
        must_match_var = (
            ("m_shape", m_shape),
            ("v_shape", v_shape),
            ("gradient_shape", gradient_shape),
        )
        for shape_label, shape_value in must_match_var:
            validator.check("var_shape", var_shape, shape_label, shape_value, validator.EQ, self.name)
        return var_shape

    def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
                    epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype):
        """Validate the tensor and hyper-parameter dtypes; the output dtype is that of `var`."""
        # var, both moments and the gradient must share one dtype out of float16/float32.
        tensor_dtypes = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": gradient_dtype}
        validator.check_tensors_dtypes_same_and_valid(tensor_dtypes, [mstype.float16, mstype.float32], self.name)

        # The hyper-parameters are checked together against float32.
        hyper_param_dtypes = {"lr": lr_dtype, "decay": decay_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype,
                              "epsilon": epsilon_dtype}
        validator.check_scalar_or_tensor_types_same(hyper_param_dtypes, [mstype.float32], self.name, True)
        return var_dtype
195
+
196
+
197
class Dequant(PrimitiveWithInfer):
    r"""
    Returns the dequantized value of input_x.
    This operation will do ReLU to the dequantized value if `relu_flag` is True.

    If `sqrt_mode` is False:

    .. math::
        y = x * deq\_scale

    If `sqrt_mode` is True:

    .. math::
        y = x * deq\_scale * deq\_scale

    Note:
        This operation only supports Atlas 200/300/500 inference products.

    Args:
        sqrt_mode (bool): Specifies whether to perform square root on `scale`. Default: ``False``.
        relu_flag (bool): Specifies whether to perform ReLU. Default: ``False``.
        dtype: Stored on the primitive as `dtype`. Default: ``mstype.float16``.

    Inputs:
        - **input_x** (Tensor) - Input tensor. Must be mindspore.int32.
        - **deq_scale** (Tensor) - Specifies the scaling ratio.
          Data type must be mindspore.float16 or mindspore.uint64

    Outputs:
        - Tensor: The dequantized output tensor of type mindspore.float16.

    Examples:
        >>> input_x = Tensor([100.0, 150.0], mstype.float32)
        >>> dequant = ops.Dequant(False, False)
        >>> y = dequant(input_x)

    NOTE(review): the example above builds a float32 input and passes no
    `deq_scale`, which does not match the int32 / two-input checks performed
    by `infer_dtype` - confirm and update the example.
    """

    @prim_attr_register
    def __init__(self, sqrt_mode=False, relu_flag=False, dtype=mstype.float16):
        """Validate the two boolean flags and store them together with `dtype`."""
        prim_name = self.name
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], prim_name)
        self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], prim_name)
        self.dtype = dtype

    def infer_shape(self, x_shape, deq_scale_shape):
        """The output has the same shape as `x`; the shape of `deq_scale` is not checked."""
        return x_shape

    def infer_dtype(self, x_type, deq_scale_type):
        """Check the input dtypes and report float16 as the output dtype."""
        x_valid_dtypes = [mstype.int32]
        scale_valid_dtypes = [mstype.float16, mstype.uint64]

        validator.check_subclass("x", x_type, mstype.tensor_type, self.name)
        validator.check_type_name("x", x_type, x_valid_dtypes, self.name)
        validator.check_type_name("deq_scale", deq_scale_type, scale_valid_dtypes, self.name)
        # NOTE(review): the result type is fixed here and does not consult the
        # `dtype` constructor argument - confirm this is intended.
        return mstype.float16
247
+
248
+
249
class AntiQuant(Primitive):
    r"""
    Returns the antiquantized value of input_x.

    If `sqrt_mode` is False:

    .. math::
        y = scale * (x + offset)

    If `sqrt_mode` is True:

    .. math::
        y = scale * scale * (x + offset)

    Note:
        This operation only supports Atlas 200/300/500 inference products.

    Args:
        sqrt_mode (bool): Specifies whether to perform square root on `scale`. Default: ``False``.
        dtype: Stored on the primitive as `dtype`. Default: ``mstype.float16``.

    Inputs:
        - **x** (Tensor) - Input tensor. Must be mindspore.int8.
        - **scale** - Specifies the scaling ratio (registered as the second input).
        - **offset** - Specifies the offset (registered as the third input).

    Outputs:
        - **y** (Tensor) - The antiquantized output tensor.
          NOTE(review): earlier documentation stated mindspore.float32 while the
          `dtype` argument defaults to float16 - confirm the actual output type.

    Examples:
        >>> from mindspore.ops.operations._inner_ops import AntiQuant
        >>> input_x = Tensor([50.0, 20.0], mstype.int8)
        >>> antiquant = AntiQuant(False)
        >>> # With scale = 2.0 and offset = 1.0 supplied as inputs. The accepted
        >>> # types for `scale` and `offset` are not visible here - TODO confirm.
        >>> y = antiquant(input_x, scale, offset)
        >>> print(y)
        [102. 42.]
    """

    @prim_attr_register
    def __init__(self, sqrt_mode=False, dtype=mstype.float16):
        # Explicitly initialise the base primitive under the name "AntiQuant"
        # before any attribute is validated or stored.
        super().__init__("AntiQuant")
        # `sqrt_mode` must be a bool; anything else is rejected by the validator.
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        # `dtype` is stored as given, without validation.
        self.dtype = dtype

        # `scale` and `offset` are declared as primitive inputs, not constructor arguments.
        self.init_prim_io_names(inputs=['x', 'scale', 'offset'],
                                outputs=['y'])
294
+
295
+
296
class MatrixDiag(PrimitiveWithInfer):
    """
    Returns a batched diagonal tensor with the given batched diagonal values.

    Inputs:
        - **x** (Tensor) - A tensor to be multiplied element-wise by `assist`. It can be one of the following
          data types: float32, float16, int32, int8, and uint8.
        - **assist** (Tensor) - An eye tensor of the same type as `x`. Its rank must be greater than or equal
          to 2 and its last dimension must be equal to the second to last dimension.

    Outputs:
        Tensor, has the same type and shape as input `assist`.

    Examples:
        >>> x = Tensor(np.array([1, -1]), mstype.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag = ops.MatrixDiag()
        >>> result = matrix_diag(x, assist)
        >>> print(result)
        [[[-12. 11.]
          [-10. 9.]]
         [[ -8. 7.]
          [ -6. 5.]]
         [[ -4. 3.]
          [ -2. 1.]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixDiag"""

    def infer_dtype(self, x_dtype, assist_dtype):
        """Require `x` and `assist` to share one of the supported dtypes and return it."""
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        """
        Validate `x_shape` against `assist_shape` and return `assist_shape`.

        Checks performed:
            - `assist` has rank >= 2 and its last two dimensions are equal;
            - rank of `x` + 1 <= rank of `assist`;
            - walking from the last dimension of `x` backwards, every dimension of
              `x` that is not 1 equals the dimension of `assist` one position
              further to the left (index ``r_idx - 1``).
        """
        validator.check_int(len(assist_shape), 2, validator.GE, "assist rank", self.name)
        validator.check('rank of x', len(x_shape) + 1,
                        'rank of assist', len(assist_shape), validator.LE, self.name)
        validator.check("assist's penultimate dimension", assist_shape[-2], "assist's last dimension",
                        assist_shape[-1], validator.EQ, self.name)

        # Negative indices -1, -2, ..., -len(x_shape): iterate x from its last dim backwards.
        for r_idx in range(-1, -len(x_shape) - 1, -1):
            if x_shape[r_idx] == 1:
                # Size-1 dimensions of x are not compared against assist.
                continue
            assist_idx = r_idx - 1
            # Label both operands with their *index*; the previous code formatted the
            # assist dimension's value into the label, producing a misleading message.
            validator.check("reverse x dim %d" % r_idx, x_shape[r_idx],
                            "reverse assist dim %d" % assist_idx, assist_shape[assist_idx],
                            validator.EQ, self.name)

        return assist_shape
349
+
350
+
351
class MatrixDiagPart(PrimitiveWithInfer):
    r"""
    Returns the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **assist** (Tensor) - An eye tensor of the same type as `x`. With shape same as `x`.

    Outputs:
        Tensor, data type same as input `x`. The shape must be x.shape[:-2] + [min(x.shape[-2:])].

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag_part = ops.MatrixDiagPart()
        >>> result = matrix_diag_part(x, assist)
        >>> print(result)
        [[12., -9.], [8., -5.], [4., -1.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixDiagPart"""

    def infer_dtype(self, x_dtype, assist_dtype):
        """Require `x` and `assist` to share one of the supported dtypes and return it."""
        supported_dtypes = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        dtypes_by_name = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(dtypes_by_name, supported_dtypes, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        """Check rank and shape agreement, then drop the larger of the last two dimensions."""
        validator.check_int(len(x_shape), 2, validator.GE, "x rank", self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, validator.EQ, self.name)

        rows, cols = assist_shape[-2], assist_shape[-1]
        if rows < cols:
            # Fewer rows than columns: keep the row count as the diagonal length.
            return assist_shape[:-1]
        # Otherwise keep the column count as the diagonal length.
        return assist_shape[:-2] + assist_shape[-1:]
391
+
392
+
393
class MatrixSetDiag(PrimitiveWithInfer):
    r"""
    Modifies the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data
          types: float32, float16, int32, int8, uint8.
        - **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k,
          where k >= 1.
        - **assist** (Tensor) - An eye tensor of the same type as `x`. With shape same as `x`.

    Outputs:
        Tensor, data type same as input `x`. The shape same as `x`.

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
        >>> matrix_set_diag = ops.MatrixSetDiag()
        >>> result = matrix_set_diag(x, diagonal)
        >>> print(result)
        [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]]

    NOTE(review): the example above omits the `assist` input that the
    inference methods expect as a third argument - confirm and update.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixSetDiag"""

    def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
        """Require all three inputs to share one of the supported dtypes and return it."""
        supported_dtypes = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        dtypes_by_name = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(dtypes_by_name, supported_dtypes, self.name)
        return x_dtype

    def infer_shape(self, x_shape, diagonal_shape, assist_shape):
        """Check `x`, `assist` and `diagonal` shapes for consistency and return `assist_shape`."""
        validator.check_int(len(x_shape), 2, validator.GE, "x rank", self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, validator.EQ, self.name)

        # The diagonal must be shaped like x with the larger of its last two dims removed.
        if x_shape[-2] < x_shape[-1]:
            expected_label = "x shape excluding the last dimension"
            expected_shape = x_shape[:-1]
        else:
            expected_label = "x shape excluding the second last dimension"
            expected_shape = x_shape[:-2] + x_shape[-1:]
        validator.check("diagonal shape", diagonal_shape, expected_label, expected_shape,
                        validator.EQ, self.name)

        return assist_shape
438
+
439
+
440
class ConfusionMulGrad(PrimitiveWithInfer):
    """
    `output0` is the dot product result of input0 and input1.

    `output1` is the dot product result of input0 and input1, then apply the reducesum operation on it.

    Args:
        axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
            Default:(), reduce all dimensions. Only constant value is allowed.
        keep_dims (bool):

            - If true, keep these reduced dimensions and the length as 1.
            - If false, don't keep these dimensions. Default:False.

    Inputs:
        - **input_0** (Tensor) - The input Tensor.
        - **input_1** (Tensor) - The input Tensor.
        - **input_2** (Tensor) - The input Tensor.

    Outputs:
        - **output_0** (Tensor) - The same shape as `input0`.
        - **output_1** (Tensor)

            - If axis is (), and keep_dims is false, the output is a 0-D array representing
              the sum of all elements in the input array.
            - If axis is int, set as 2, and keep_dims is false,
              the shape of output is :math:`(x_1,x_3,...,x_R)`.
            - If axis is tuple(int), set as (2,3), and keep_dims is false,
              the shape of output is :math:`(x_1,x_4,...x_R)`.

    Examples:
        >>> confusion_mul_grad = ops.ConfusionMulGrad()
        >>> input_0 = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
        >>> input_1 = Tensor(np.random.randint(0, 4, (2, 3)), mindspore.float32)
        >>> input_2 = Tensor(np.random.randint(-4, 0, (2, 3)), mindspore.float32)
        >>> output_0, output_1 = confusion_mul_grad(input_0, input_1, input_2)
        output_0:
            [[ 3.   1.   0.]
             [-6.   2.  -2.]]
        output_1:
            -3.0
    """

    @prim_attr_register
    def __init__(self, axis=(), keep_dims=False):
        """Register IO names and store the validated reduction settings."""
        self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
        self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
        self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)

    def infer_shape(self, input0_shape, input1_shape, input2_shape):
        """Return (`input0_shape`, `input1_shape` reduced over `axis`)."""
        reduced_shape = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name)
        return input0_shape, reduced_shape

    def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype):
        """Require every input to be a tensor; outputs take the dtypes of input0 and input1."""
        named_dtypes = (
            ("input0_dtype", input0_dtype),
            ("input1_dtype", input1_dtype),
            ("input2_dtype", input2_dtype),
        )
        for arg_name, arg_dtype in named_dtypes:
            validator.check_subclass(arg_name, arg_dtype, mstype.tensor_type, self.name)
        return input0_dtype, input1_dtype
499
+
500
+
501
class ConvertToDynamic(PrimitiveWithCheck):
    """
    This op is used for dynamic rank testing. Its inferred shape will be unknown
    during compile time, so that its output will appear to be dynamically ranked.
    The input will not be altered in any way. Put this operator before the operator
    being tested for dynamic rank support.

    Args:
        is_dynamic_rank (bool): If true, convert to dynamic rank.
            If false, convert to dynamic shape. Default: ``False``.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> from mindspore.ops import operations as P
        >>> class TestDynamicNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(TestDynamicNet, self).__init__()
        >>>         self.convert_to_dynamic = inner.ConvertToDynamic()
        >>>         # suppose we are testing Reshape op
        >>>         self.reshape = P.Reshape()
        >>>
        >>>     def construct(self, input, new_shape):
        >>>         dynamic_input = self.convert_to_dynamic(input)
        >>>         reshaped_input = self.reshape(dynamic_input, new_shape)
        >>>         return reshaped_input
        >>>
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")
        >>> input = Tensor(np.array([0, 1, 2, 3]))
        >>> new_shape = (2, 2)
        >>> net = TestDynamicNet()
        >>> output = net(input, new_shape)
        >>> print(output)
        [[0, 1], [2, 3]]
    """

    @prim_attr_register
    def __init__(self, is_dynamic_rank=False):
        """Validate `is_dynamic_rank` and register the IO names."""
        validator.check_value_type('is_dynamic_rank', is_dynamic_rank, [bool], self.name)
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def check_shape(self, input_shape):
        """Require the input to have rank greater than 0."""
        rank = len(input_shape)
        validator.check("input_shape rank", rank, "", 0, validator.GT, self.name)

    def check_dtype(self, input_dtype):
        """Require the input to be a tensor."""
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
556
+
557
+
558
class GpuConvertToDynamicShape(PrimitiveWithCheck):
    """
    This op is used for dynamic shape testing. Its inferred shape will be unknown
    during compile time, so that its output will appear to be dynamically shaped.
    The input will not be altered in any way. Put this operator before the operator
    being tested for dynamic shape support.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> from mindspore.ops import operations as P
        >>> class TestDynamicShapeReshapeNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(TestDynamicShapeReshapeNet, self).__init__()
        >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        >>>         # suppose we are testing Reshape op
        >>>         self.reshape = P.Reshape()
        >>>
        >>>     def construct(self, input, new_shape):
        >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        >>>         reshaped_input = self.reshape(dynamic_shape_input, new_shape)
        >>>         return reshaped_input
        >>>
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0, 1, 2, 3]))
        >>> new_shape = (2, 2)
        >>> net = TestDynamicShapeReshapeNet()
        >>> output = net(input, new_shape)
        >>> print(output)
        [[0, 1], [2, 3]]
    """

    @prim_attr_register
    def __init__(self):
        """Register the IO names of GpuConvertToDynamicShape."""
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def check_shape(self, input_shape):
        """Require the input to have rank greater than 0."""
        rank = len(input_shape)
        validator.check("input_shape rank", rank, "", 0, validator.GT, self.name)

    def check_dtype(self, input_dtype):
        """Require the input to be a tensor."""
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
606
+
607
+
608
class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
    """
    This op is used for dynamic shape testing. The only purpose of this operator is
    that it will throw a value error if the input is dynamically shaped.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Raises:
        ValueError: If any dimension of the input shape is -1.

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> from mindspore.ops import operations as P
        >>> class AssertDynamicShapeNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(AssertDynamicShapeNet, self).__init__()
        >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        >>>         self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()
        >>>
        >>>     def construct(self, input):
        >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        >>>         self.error_on_dynamic_shape_input(dynamic_shape_input)
        >>>
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0]))
        >>> net = AssertDynamicShapeNet()
        >>> output = net(input)
        ValueError: Input is dynamically shaped.
    """

    @prim_attr_register
    def __init__(self):
        """Register the IO names of ErrorOnDynamicShapeInput."""
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def infer_shape(self, input_shape):
        """Return `input_shape` unchanged, raising if any dimension is -1."""
        if any(dim == -1 for dim in input_shape):
            raise ValueError("Input is dynamically shaped.")
        return input_shape

    def infer_dtype(self, input_dtype):
        """Require the input to be a tensor and return its dtype unchanged."""
        # The sibling PrimitiveWithInfer subclasses in this file all name this hook
        # `infer_dtype`; the original only defined `infer_type`.
        # NOTE(review): presumably the framework only looks up `infer_dtype` - confirm.
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
        return input_dtype

    def infer_type(self, input_dtype):
        """Infer the dtype of input for ErrorOnDynamicShapeInput.

        Kept for backward compatibility; delegates to `infer_dtype`.
        """
        return self.infer_dtype(input_dtype)

    def infer_value(self, input_tensor):
        """Pass the input value through unchanged."""
        return input_tensor
662
+
663
+
664
class SequenceMask(PrimitiveWithCheck):
    """
    Builds a boolean mask that marks the first N positions of each cell.

    If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type and shape
    [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])

    Inputs:
        - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor should be
          less than or equal to `maxlen`. Values greater than `maxlen` will be treated as `maxlen`.
          Must be type int32 or int64.

        - **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same
          type as elements in `lengths`.

    Outputs:
        One mask tensor of shape lengths.shape + (maxlen,).

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> import numpy as np
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> sequence_mask = ops.SequenceMask()
        >>> output = sequence_mask(x, 3)
        >>> print(output)
        [[[True False False]
          [True True True]]
         [[True True False]
          [False False False]]]
    """

    @prim_attr_register
    def __init__(self):
        """Register the IO names of SequenceMask."""
        self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"])

    def check_shape(self, lengths_shape, maxlen_shape):
        """Require `lengths` to have rank > 0 and `maxlen` to have rank 0."""
        lengths_rank = len(lengths_shape)
        maxlen_rank = len(maxlen_shape)
        validator.check("lengths_shape", lengths_rank, "", 0, validator.GT, self.name)
        validator.check("maxlen_shape", maxlen_rank, "", 0, validator.EQ, self.name)

    def check_dtype(self, lengths_dtype, maxlen_dtype):
        """Require `lengths` to be a tensor and `maxlen` to be a number."""
        validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor_type, self.name)
        validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)
709
+
710
+
711
class SyncBatchNorm(Primitive):
    r"""
    Sync Batch Normalization for input data and updated parameters.

    Sync Batch Normalization is cross device synchronized Batch Normalization. Batch Normalization is
    widely used in convolutional neural networks. This operation applies Batch Normalization over input
    to avoid internal covariate shift as described in the paper `Batch Normalization: Accelerating
    Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
    It rescales and recenters the features using a mini-batch of data and the learned parameters which
    can be described in the following formula,

    .. math::
        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta

    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.

    Args:
        epsilon (float): A small value added for numerical stability. Default: 1e-5.
        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
            Momentum value must be [0, 1]. Default: 0.1.
        group (str): The communication group to work on. Default: "sync_bn_group0".
        device_num (int): The number of devices in each group. Default: 2.

    Inputs:
        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
        - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `mean`.

    Outputs:
        Tuple of 5 Tensor, the normalized inputs and the updated parameters.

        - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`.
        - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to nn.SyncBatchNorm for direct use.
        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
        >>> scale = Tensor(np.ones([2]), mindspore.float32)
        >>> bias = Tensor(np.ones([2]), mindspore.float32)
        >>> mean = Tensor(np.ones([2]), mindspore.float32)
        >>> variance = Tensor(np.ones([2]), mindspore.float32)
        >>> sync_batch_norm = ops._inner_ops.SyncBatchNorm()
        >>> output = sync_batch_norm(input_x, scale, bias, mean, variance)
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Float32, value=
         [[ 1.00000000e+00, 1.00000000e+00],
          [ 1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[2], dtype=Float32, value=
         [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
         [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
         [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
         [ 1.00000000e+00, 1.00000000e+00]))
    """

    @prim_attr_register
    def __init__(self, epsilon=1e-5, momentum=0.1, group="sync_bn_group0", device_num=2):
        """Validate the hyper-parameters, then register IO names and primitive attributes."""
        # Argument validation, in the same order as the signature.
        validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
        validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
        validator.check_isinstance("group", group, str)
        validator.check_int(device_num, 2, validator.GE, "device_num", self.name)

        input_names = ['x', 'scale', 'offset', 'mean', 'variance']
        output_names = ['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']
        self.init_prim_io_names(inputs=input_names, outputs=output_names)

        self.add_prim_attr('side_effect_mem', True)
        self.add_prim_attr('format', 'NCHW')
784
+
785
+
786
class Centralization(PrimitiveWithInfer):
    """
    Computes centralization. y = x - mean(x, axis).

    Note:
        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The data type must be float16 or float32.
        - **axis** (Union[int, Tuple(int), List(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
          Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)).

    Outputs:
        Tensor, has the same shape and dtype as the `input_x`.

    Raises:
        ValueError: If `axis` is not a constant value.
        TypeError: If `axis` is not one of the following types: int, list, tuple.
        TypeError: If `axis` has non-Int elements.
        ValueError: If `axis` or one of its elements is outside `[-rank(input_x), rank(input_x))`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> mindspore.set_seed(1)
        >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
        >>> centralization = ops.Centralization()
        >>> output = centralization(input_x, -1)
        >>> print(output)
        [[ 1.1180509 -1.1180508]
         [ 0.2723984 -0.2723984]]
    """

    __mindspore_signature__ = (
        sig.make_sig('input_x'),
        sig.make_sig('axis', default=())
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Centralization"""
        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output'])

    def __infer__(self, input_x, axis):
        """Validate dtype and axis; the output keeps the shape and dtype of `input_x`."""
        x_shape = list(input_x['shape'])
        x_dtype = input_x['dtype']
        axis_v = axis['value']
        rank = len(x_shape)

        args = {'input_x': x_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        # A missing value means the axis is not known at compile time.
        if axis_v is None:
            raise ValueError(f"For {self.name}, axis must be const.")
        validator.check_value_type('axis', axis_v, [int, list, tuple], self.name)

        if isinstance(axis_v, int):
            validator.check_int_range(axis_v, -rank, rank, validator.INC_LEFT, 'axis', self.name)
        else:
            # The original tested `axis` (the always-truthy input dict) here instead of
            # `axis_v`, and never range-checked the elements as the int branch does.
            for index, one_axis in enumerate(axis_v):
                elem_name = 'axis[%d]' % index
                validator.check_value_type(elem_name, one_axis, [int], self.name)
                validator.check_int_range(one_axis, -rank, rank, validator.INC_LEFT, elem_name, self.name)

        out = {'shape': x_shape,
               'dtype': x_dtype,
               'value': None}
        return out
851
+
852
+
853
class StackInit(PrimitiveWithInfer):
    """
    Create a stack that produces tensors in first-in last-out order.

    After `StackInit`, a tensor can be pushed onto the stack using `StackPush`, and popped
    at the top of the stack using `StackPop`. Finally, the stack should be destroyed with `StackDestroy`.

    Args:
        index (int): The index of the stack. Default: 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> index = 0
        >>> stack = ops.StackInit(index)
        >>> push = ops.StackPush(index)
        >>> pop = ops.StackPop(index, x.shape, x.dtype)
        >>> destroy = ops.StackDestroy(index)
        >>> stack()
        >>> push(x)
        >>> y = pop()
        >>> destroy()
        >>> print(y)
        [[1 3]
         [2 0]]
    """

    @prim_attr_register
    def __init__(self, index=1):
        """Initialize StackInit; `index` must be an int."""
        validator.check_value_type("index", index, [int], self.name)
886
+
887
+
888
class StackPush(PrimitiveWithInfer):
    """
    Push a tensor onto the stack.

    Before `StackPush`, the stack should be created using `StackInit`.
    Please refer to the usage in source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.

    Inputs:
        - **input** (Tensor) - A tensor to be pushed onto the stack.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1):
        """Initialize StackPush; `index` must be an int. The op has one input and no outputs."""
        validator.check_value_type("index", index, [int], self.name)
        self.init_prim_io_names(inputs=['input'], outputs=[])
913
+
914
+
915
class StackPop(PrimitiveWithInfer):
    """
    Pop the tensor at the top of the stack.

    Before `StackPop`, the stack should be created using `StackInit`.
    Please refer to the usage in source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.
        shape (tuple): The shape of the tensor at the top of the stack. Default: (1,).
        dtype (mindspore.dtype): The type of the tensor at the top of the stack. Default: mindspore.float32.

    Outputs:
        - **output** (Tensor) - The tensor at the top of the stack.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1, shape=(1,), dtype=mstype.float32):
        """Validate `index`, `shape` and `dtype`, then remember the output shape and dtype."""
        validator.check_value_type("index", index, [int], self.name)

        # `shape` must be a flat list/tuple whose elements are ints >= 1.
        validator.check_value_type('shape type', shape, [list, tuple], self.name)
        shape_ndim = len(np.array(shape).shape)
        validator.check_int(shape_ndim, 1, validator.EQ, "dim of shape", self.name)
        for elem in shape:
            validator.check_int(elem, 1, validator.GE, 'shape element', self.name)
            validator.check_value_type('type of shape element', elem, [int], self.name)

        allowed_dtypes = (mstype.bool_,) + mstype.number_type
        validator.check_type_name("dtype", dtype, allowed_dtypes, self.name)
        self.shape = shape
        self.dtype = dtype

        self.init_prim_io_names(inputs=[], outputs=['output'])

    def __infer__(self):
        """Report the shape and dtype given at construction; the value is unknown at compile time."""
        out = {'shape': list(self.shape),
               'dtype': self.dtype,
               'value': None}
        return out
958
+
959
+
960
class StackDestroy(PrimitiveWithInfer):
    """
    Destroy the stack.

    Before `StackDestroy`, the stack should be created using `StackInit`.
    Please refer to the usage in source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1):
        """Initialize StackDestroy; `index` must be an int."""
        validator.check_value_type("index", index, [int], self.name)
981
+
982
+
983
class DynamicStitch(PrimitiveWithCheck):
    r"""
    Interleave the values from the data tensors into a single tensor.

    Inputs:
        - **indices** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
        - **data** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.

    Outputs:
        Tensor. A stacked Tensor with the same type as `data`.

    Raises:
        TypeError: If the data types of elements in `data` or `indices` are not the same.
        ValueError: If the length of `data` or `indices` is not greater than 1.
        ValueError: If the shape of `data[i]` does not start with the shape of `indices[i]`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x1 = Tensor([6], mstype.int32)
        >>> x2 = Tensor(np.array([4, 1]), mstype.int32)
        >>> x3 = Tensor(np.array([[5, 2], [0, 3]]), mstype.int32)
        >>> y1 = Tensor(np.array([[61, 62]]), mstype.int32)
        >>> y2 = Tensor(np.array([[41, 42], [11, 12]]), mstype.int32)
        >>> y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mstype.int32)
        >>> stitch = ops.DynamicStitch()
        >>> output = stitch([x1, x2, x3], [y1, y2, y3])
        >>> print(output)
        [[ 1  2]
         [11 12]
         [21 22]
         [31 32]
         [41 42]
         [51 52]
         [61 62]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DynamicStitch"""

    def check_shape(self, indices_shape, data_shape):
        """Validate the per-tensor shapes of `indices` and `data`.

        Args:
            indices_shape (Union[tuple, list]): One shape per tensor in `indices`.
            data_shape (Union[tuple, list]): One shape per tensor in `data`.

        Returns:
            list, `[-1]` followed by the trailing dims of `data[0]` that are not covered by `indices[0]`.

        Raises:
            ValueError: If `data[i]`'s shape does not start with `indices[i]`'s shape.
        """
        validator.check_value_type("shape of indices", indices_shape, [tuple, list], self.name)
        validator.check_int(len(indices_shape), 1, validator.GE, "len of indices_shape", self.name)
        indices_dim0 = len(indices_shape[0])
        indices_num = len(indices_shape)

        validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
        validator.check_int(len(data_shape), 1, validator.GE, "len of data_shape", self.name)
        data_dim0 = len(data_shape[0])
        # The original used len(indices_shape) here, which made the size check below vacuous.
        data_num = len(data_shape)

        validator.check("size of indices", indices_num, 'size of data', data_num, validator.EQ, self.name)

        # shape of `data` must start with shape of `indices`
        for i in range(indices_num):
            indices_dim = len(indices_shape[i])
            data_dim = len(data_shape[i])
            validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, validator.LE, self.name)
            # The original compared the data prefix with itself, so this check never fired.
            # Negative dims are skipped: elsewhere in this file -1 marks a dynamic dimension.
            # NOTE(review): confirm that skipping unknown dims is the desired policy here.
            prefix_mismatch = any(
                data_d >= 0 and index_d >= 0 and data_d != index_d
                for data_d, index_d in zip(data_shape[i][:indices_dim], indices_shape[i])
            )
            if prefix_mismatch:
                raise ValueError(f"data[{i}].shape: {data_shape[i]} does not start with "
                                 f"indices[{i}].shape: {indices_shape[i]}")

        # the last-(data_dim0-indices_dim0)-dim of data shape must end with same shape.
        base_extra = data_dim0 - indices_dim0
        for i in range(data_num):
            indices_dim = len(indices_shape[i])
            data_dim = len(data_shape[i])
            extra = data_dim - indices_dim
            validator.check(f"extra dim of data[{i}]", extra,
                            "extra dim of data[0]", base_extra, validator.EQ, self.name)
            validator.check(f"data[0].shape[{indices_dim0}:]", data_shape[0][indices_dim0:],
                            f"data[{i}].shape[{indices_dim}:]",
                            data_shape[i][indices_dim:], validator.EQ, self.name)

        # list(...) keeps the concatenation valid when the shape is a tuple.
        out_shape = [-1] + list(data_shape[0][indices_dim0:])
        return out_shape

    def check_dtype(self, indices_type, data_type):
        """Require int32 `indices` tensors and `data` tensors that all share the dtype of `data[0]`.

        Returns:
            The dtype of `data[0]`.
        """
        validator.check_subclass("indices[0]", indices_type[0], mstype.tensor_type, self.name)
        validator.check_subclass("data[0]", data_type[0], mstype.tensor_type, self.name)
        indices_num = len(indices_type)
        for i in range(indices_num):
            validator.check_tensor_dtype_valid(f'indices[{i}]', indices_type[i], mstype.int32, self.name)
            validator.check_tensor_dtype_valid(f'data[{i}]', data_type[i],
                                               mstype.number_type + (mstype.bool_,), self.name)
            validator.check(f"type of data[{i}]", data_type[i], "type of data[0]",
                            data_type[0], validator.EQ, self.name)
        return data_type[0]
1071
+
1072
+
1073
class DynamicBroadcastGradientArgs(Primitive):
    """
    Broadcast the two input shapes, return the dimensions that each need to be broadcast.

    Input shape `s0` and shape `s1` can be broadcast to a common shape if for each dimension pair they are either equal
    or input is one or the target dimension is -1. In case of -1 in target shape, it will be replaced by the input
    shape's value in that dimension.

    Inputs:
        - **s0** (Tensor) - A `1-D` tensor. The data type should be one of the following types: int32, int64,
          uint32, uint64.
        - **s1** (Tensor) - A `1-D` tensor with the same type as `s0`.

    Outputs:
        Tuple(Tensor), tuple of 2 tensors, r0 and r1. The first one is the index tensor and the other one is the mask
        tensor.

        - **r0** (Tensor) - The output shape is 1-D with the same type as s0.
        - **r1** (Tensor) - The output shape is 1-D with the same type as s0.

    Raises:
        ValueError: if the `s0` and `s1` are incompatible, or if a -1 in the target shape is in an invalid
                    location.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> shape0 = (4, 2, 1)
        >>> shape1 = (2, 7)
        >>> from mindspore.ops.operations import _inner_ops
        >>> args = _inner_ops.DynamicBroadcastGradientArgs()
        >>> r0, r1 = args(Tensor(shape0), Tensor(shape1))
        >>> print(r0, r1)
        [2], [0]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DynamicBroadcastGradientArgs"""
1113
+
1114
+
1115
class DSDMatmul(PrimitiveWithInfer):
    """
    The definition of the DSDMatmul primitive.

    Takes three inputs (`input_w1`, `input_w2`, `input_v`) and produces one output (`output_y`).
    """

    @prim_attr_register
    def __init__(self):
        """Register the IO names of DSDMatmul."""
        self.init_prim_io_names(inputs=['input_w1', 'input_w2', 'input_v'], outputs=['output_y'])

    def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape):
        """Derive the 6-D output shape from the first two dims of `input_w1` and `input_v`."""
        batch_size, head = input_w1_shape[0], input_w1_shape[1]
        # The output is laid out in 16 x 16 tiles.
        seq_len = input_v_shape[0] * 16 // batch_size
        v_embedding = input_v_shape[1] * 16 // head
        return (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)

    def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3):
        """The output takes the dtype of the first input."""
        return data_dtype1
1133
+
1134
+
1135
class MatmulDDS(PrimitiveWithInfer):
    """MatmulDDS definition.

    Takes `q`, `k`, `local_mask` and `global_mask`; produces `local_prob` and `global_prob`.
    """

    @prim_attr_register
    def __init__(self, bs, heads):
        """init MatmulDDS"""
        self.init_prim_io_names(inputs=['q', 'k', 'local_mask', 'global_mask'],
                                outputs=['local_prob', 'global_prob'])

        self.heads = heads

    def infer_shape(self, q, k, local_mask, global_mask):
        """Derive the 7-D shapes of the local and global outputs from `q` and `local_mask`."""
        seq_len = local_mask[0] * local_mask[-1]
        bs = q[1] * q[2] // seq_len
        global_size = seq_len // 4

        hidden = q[0] * q[-1]
        size_per_head = hidden // self.heads
        heads = hidden // size_per_head

        block_size = local_mask[1] * local_mask[2] // bs
        block_num = seq_len // block_size

        # Both outputs share every dim except the fourth; the trailing dims are 16 x 16 tiles.
        leading = (bs, heads, block_num)
        trailing = (block_size // 16, 16, 16)
        l_size = leading + (block_size // 16,) + trailing
        g_size = leading + (global_size // 16,) + trailing

        return l_size, g_size

    def infer_dtype(self, q, k, local_mask, global_mask):
        """Both outputs take the dtype of `q`."""
        return q, q
1161
+
1162
+
1163
class DSDGrad(PrimitiveWithInfer):
    """
    The definition of the DSDGrad primitive.

    Takes five inputs (`w1_gm`, `w2_gm`, `v_gm`, `a_gm`, `d_a_gm`) and produces three outputs
    (`d_w1_gm`, `d_w2_gm`, `d_v_gm`).
    """

    @prim_attr_register
    def __init__(self):
        """Register the IO names of DSDGrad."""
        input_names = ['w1_gm', 'w2_gm', 'v_gm', 'a_gm', 'd_a_gm']
        output_names = ['d_w1_gm', 'd_w2_gm', 'd_v_gm']
        self.init_prim_io_names(inputs=input_names, outputs=output_names)

    def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape, input_a_shape, input_da_shape):
        """Each output has the shape of the matching one of the first three inputs."""
        return input_w1_shape, input_w2_shape, input_v_shape

    def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3, data_dtype4, data_dtype5):
        """All three outputs take the dtype of the first input."""
        return data_dtype1, data_dtype1, data_dtype1
1178
+
1179
+
1180
class MatmulDDSGrad(PrimitiveWithInfer):
    """MatmulDDSGrad definition.

    Takes `q`, `k`, `local_prob`, `global_prob`, `local_prob_grad` and `global_prob_grad`;
    produces `dq` and `dk`.
    """

    @prim_attr_register
    def __init__(self):
        """init MatmulDDSGrad"""
        input_names = ['q', 'k', 'local_prob', 'global_prob', 'local_prob_grad', 'global_prob_grad']
        self.init_prim_io_names(inputs=input_names, outputs=['dq', 'dk'])

    def infer_shape(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad):
        """`dq` has the shape of `q`; `dk` has the first four dims of `q` swapped pairwise."""
        d0, d1, d2, d3 = q[0], q[1], q[2], q[3]
        k_size = (d1, d0, d3, d2)

        return q, k_size

    def infer_dtype(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad):
        """The outputs take the dtypes of `q` and `k` respectively."""
        return q, k
1196
+
1197
+
1198
class NonZeroWithValue(Primitive):
    """
    Returns the value of elements that are non-zero (in row-major order - by dimension).

    Args:
        transpose (bool): Must be a bool. Default: ``False``.

    Inputs:
        - **x** (Tensor), input array of rank >= 2.

    Outputs:
        elements that are non-zero. The registered output names are `value`, `index` and `count`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> op = NonZeroWithValue()
        >>> data = Tensor(np.array([[1, 0, 0], [0, 0, 1]]), mindspore.float32)
        >>> value, index, count = op(data)
        >>> print(value)
        [1.0, 1.0]
    """

    @prim_attr_register
    def __init__(self, transpose=False):
        """Initialize NonZeroWithValue"""
        validator.check_value_type("transpose", transpose, [bool], self.name)
        output_names = ['value', 'index', 'count']
        self.init_prim_io_names(inputs=['x'], outputs=output_names)
1224
+
1225
+
1226
class NonZeroWithValueShape(Primitive):
    """
    Returns the value and index of elements that are non-zero (in row-major order - by dimension).

    Inputs:
        The registered input names are `value`, `index` and `count`; the example below feeds them
        from the outputs of `NonZeroWithValue`.

    Outputs:
        elements that are non-zero. The registered output names are `out_value` and `out_index`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> non_zero = NonZeroWithValue()
        >>> op = NonZeroWithValueShape()
        >>> data = Tensor(np.array([[1, 0, 0], [0, 0, 1]]), mindspore.float32)
        >>> value, index, count = non_zero(data)
        >>> out_value, out_index = op(value, index, count)
        >>> print(out_index)
        [[0, 1], [0, 2]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize NonZeroWithValueShape"""
        input_names = ['value', 'index', 'count']
        output_names = ['out_value', 'out_index']
        self.init_prim_io_names(inputs=input_names, outputs=output_names)
1253
+
1254
+
1255
class DecodeImage(PrimitiveWithInfer):
    """
    Returns image data parsed from a string Tensor.

    Args:
        channels (int): Default: ``0``.
        dtype: The data type reported for the output. Default: ``mstype.uint8``.
        expand_animations (bool): Default: ``False``.

    Inputs:
        - **x** (Tensor), a Tensor of type string. 0-D. The JPEG, GIF, PNG, BMP-encoded image.

    Outputs:
        A Tensor of type uint8, uint16, float.

    Supported Platforms:
        ``Ascend``

    Examples:
    """

    @prim_attr_register
    def __init__(self, channels=0, dtype=mstype.uint8, expand_animations=False, _op_max_shape="8192,8192,3",
                 _op_max_size=[8000000]):
        """Register the io names and remember the requested output dtype."""
        # NOTE(review): `_op_max_size` has a mutable list default. It is never mutated in this class and the
        # signature is part of the public interface, so it is left untouched here.
        self.init_prim_io_names(inputs=["contents"], outputs=["image"])
        self.res_type = dtype

    def infer_shape(self, x):
        """Report a fixed output shape: two dynamic (-1) dimensions followed by 3."""
        dynamic_dim = -1
        return (dynamic_dim, dynamic_dim, 3)

    def infer_dtype(self, x):
        """Report the dtype that was passed to the constructor."""
        return self.res_type
1282
+
1283
+
1284
class SliceGetItem(Primitive):
    """
    Using SliceGetItem to get slice's attribute of 'start' 'stop' 'step'.

    Calling the primitive with a Python ``slice`` and one of the names ``"start"``, ``"stop"`` or ``"step"``
    returns that attribute. An attribute that exposes ``ndim == 1`` is converted through its ``item()`` method.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SliceGetItem"""
        self.init_prim_io_names(inputs=['slice', 'attr'], outputs=['slice_item'])

    def __call__(self, slice_value, value):
        """
        Fetch one attribute of a slice.

        Raises:
            TypeError: If `slice_value` is not a ``slice``.
            AttributeError: If `value` is none of ``"start"``, ``"stop"``, ``"step"``.
        """
        if not isinstance(slice_value, slice):
            raise TypeError(
                "Primitive[SliceGetItem] only support to get a slice type element but got {}".format(slice_value))
        for attr_name in ("start", "stop", "step"):
            if value == attr_name:
                item = getattr(slice_value, attr_name)
                # Objects reporting exactly one dimension are unwrapped to a Python scalar.
                if hasattr(item, "ndim") and item.ndim == 1:
                    return item.item()
                return item
        raise AttributeError("\'slice\' object has no attribute {}".format(value))
1311
+
1312
+
1313
class DynamicBroadcastTo(Primitive):
    """
    Broadcasts input tensor to a given shape.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The data type should be one of the following types:
          float16, float32, int32, int8, uint8.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
        - **shape** (Tensor): The target shape to broadcast.

    Outputs:
        Tensor, with the given `shape` and the same data type as `input_x`.

    Raises:
        ValueError: if the target and input shapes are incompatible.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DynamicBroadcastTo"""
        input_names = ['x', 'shape']
        output_names = ['y']
        self.init_prim_io_names(inputs=input_names, outputs=output_names)
1337
+
1338
+
1339
class DynamicResizeNearestNeighbor(Primitive):
    r"""
    Resizes the input tensor by using the nearest neighbor algorithm.

    Resizes the input tensor to a given size by using the nearest neighbor algorithm. The nearest
    neighbor algorithm selects the value of the nearest point and does not consider the
    values of neighboring points at all, yielding a piecewise-constant interpolant.

    Note:
        The operator supports dynamic shape.

    Args:
        align_corners (bool): Whether the centers of the 4 corner pixels of the input
            and output tensors are aligned. Default: ``False``.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
        - **size** (Union[tuple, list]): The target size. The dimension of size must be 2.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`.
        The data type is the same as the `input_x`.
    """

    @prim_attr_register
    def __init__(self, align_corners=False):
        """Initialize DynamicResizeNearestNeighbor"""
        validator.check_value_type("align_corners", align_corners, [bool], self.name)
        input_names = ['image_in']
        output_names = ['image_out']
        self.init_prim_io_names(inputs=input_names, outputs=output_names)
1368
+
1369
+
1370
class PsROIPooling(PrimitiveWithInfer):
    r"""
    Position Sensitive ROI-Pooling

    Args:
        pooled_height (int): Used as the third dimension of both outputs.
        pooled_width (int): Used as the fourth dimension of both outputs.
        num_rois (int): Used as the first dimension of both outputs.
        spatial_scale (float): Must be a float.
        out_dim (int): Used as the second dimension of both outputs.
        group_size (int): Must be an int.

    Inputs:
        - **features** (Tensor) - The input features, whose shape must be :math:`(N, C, H, W)`.
        - **rois** (Tensor) - The shape is :math:`(rois\_n, 5)`. With data type of float16 or float32.
          `rois_n` represents the number of RoI. The size of the second dimension must be `5` and the `5` columns
          are :math:`(image\_index, top\_left\_x, top\_left\_y, bottom\_right\_x, bottom\_right\_y)`.
          `image_index` represents the index of image. `top_left_x` and `top_left_y` represent the `x, y`
          coordinates of the top left corner of corresponding RoI, respectively. `bottom_right_x` and `bottom_right_y`
          represent the `x, y` coordinates of the bottom right corner of corresponding RoI, respectively.

    Outputs:
        - out shape(rois_num, out_channel, pool_height, pool_width), the result after pooling.
        - channel_map shape(rois_num, out_channel, pool_height, pool_width), use for back forward to compute grad

    Raises:
        TypeError: If `pooled_height`, `pooled_width`, `num_rois`, `out_dim` or `group_size` is not an int,
            or if `spatial_scale` is not a float.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> features = np.random.randn(4, 21 * 7 * 7, 80, 48)
        >>> features = Tensor.from_numpy(features).astype(mindspore.float32)
        >>> rois = Tensor.from_numpy(
        ...     np.array([
        ...         [0.0000, 150.3563, 200.1320, 579.3563, 602.3452],
        ...         [1.0000, 657.1263, 302.8564, 762.4214, 567.9854],
        ...         [2.0000, 321.3122, 232.2410, 679.0281, 587.6346],
        ...         [3.0000, 664.1630, 387.4919, 778.7322, 562.7321],
        ...     ])).astype(mindspore.float32)
        >>> psRoIPooling = inner.PsROIPooling(pooled_height=7, pooled_width=7, num_rois=4,
        ...                                   spatial_scale=1.0/16, out_dim=21,
        ...                                   group_size=7)
        >>> out, channel_map = psRoIPooling(features, rois)
        >>> print(out.shape)
        (4, 21, 7, 7)
        >>> print(channel_map.shape)
        (4, 21, 7, 7)
    """

    @prim_attr_register
    def __init__(self, pooled_height, pooled_width, num_rois, spatial_scale, out_dim, group_size):
        """Initialize PsROIPooling"""
        validator.check_value_type("pooled_height", pooled_height, [int], self.name)
        validator.check_value_type("pooled_width", pooled_width, [int], self.name)
        # Validate num_rois itself; the previous code re-checked pooled_width under this name,
        # so a non-int num_rois slipped through unnoticed.
        validator.check_value_type("num_rois", num_rois, [int], self.name)
        validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
        validator.check_value_type("out_dim", out_dim, [int], self.name)
        validator.check_value_type("group_size", group_size, [int], self.name)
        self.pooled_height = pooled_height
        self.pooled_width = pooled_width
        self.num_rois = num_rois
        self.spatial_scale = spatial_scale
        self.out_dim = out_dim
        self.group_size = group_size

    def infer_shape(self, inputs_shape, rois_shape):
        """Both outputs have shape ``[num_rois, out_dim, pooled_height, pooled_width]``; input shapes are unused."""
        output_shape = [self.num_rois, self.out_dim, self.pooled_height, self.pooled_width]
        # A separate list object, so the two returned shapes never alias each other.
        output_map_shape = list(output_shape)
        return output_shape, output_map_shape

    def infer_dtype(self, inputs_type, rois_type):
        """The pooled output keeps the features dtype; the channel map is an int32 tensor."""
        map_type = mstype.TensorType(mstype.int32)
        return inputs_type, map_type
1439
+
1440
+
1441
class ParallelResizeBilinear(PrimitiveWithInfer):
    """ParallelResizeBilinear ops"""

    @prim_attr_register
    def __init__(self, ori_image_size, split_size, src_start_w, dst_start_w, align_corners):
        """Initialize ParallelResizeBilinear."""
        validator.check_value_type("ori_image_size", ori_image_size, [list, tuple], self.name)
        validator.check_value_type("split_size", split_size, [list, tuple], self.name)
        validator.check_int(len(split_size), 2, validator.EQ, "len of split_size", self.name)
        validator.check_value_type("src_start_w", src_start_w, [int], self.name)
        validator.check_value_type("dst_start_w", dst_start_w, [int], self.name)
        validator.check_value_type("align_corners", align_corners, [bool], self.name)

        # Sequences are stored as lists; half_pixel_centers is always False for this primitive.
        self.ori_image_size = list(ori_image_size)
        self.split_size = list(split_size)
        self.src_start_w = src_start_w
        self.dst_start_w = dst_start_w
        self.align_corners = align_corners
        self.half_pixel_centers = False

        # Mirror every stored setting into the primitive attributes, in the same order as before.
        attr_names = ('ori_image_size', 'split_size', 'src_start_w', 'dst_start_w',
                      'align_corners', 'half_pixel_centers')
        for attr_name in attr_names:
            self.add_prim_attr(attr_name, getattr(self, attr_name))

    def __infer__(self, x, size):
        """
        Infer the output description from the input description dictionaries.

        The output keeps the first two dimensions and the dtype of `x`; the last two dimensions
        come from `split_size`.

        Raises:
            ValueError: If the value of `size` is None.
        """
        size_val = size['value']
        x_shape = x['shape']
        x_dtype = x['dtype']
        validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32], self.name)
        if size_val is None:
            raise ValueError("size must be const input")

        split_h, split_w = self.split_size[0], self.split_size[1]
        result = {'shape': [x_shape[0], x_shape[1], split_h, split_w]}
        result['dtype'] = x_dtype
        result['value'] = None
        return result
1478
+
1479
+
1480
class PartitionedCall(PrimitiveWithInfer):
    """
    Pass the input tensors to the subgraph and return the output tensors.

    Args:
        graph: The subgraph object; stored on the instance as `graph`.
        executor_type (str): Registered as the primitive attribute `executor_type`. Default: ``""``.

    Inputs:
        - **inputs** (Tuple), the input tensors, which will be passed to subgraph.

    Outputs:
        - outputs(Tuple), the output tensor returned by subgraph.

    Supported Platforms:
        ``Ascend``

    Examples:
    """

    @prim_attr_register
    def __init__(self, graph, executor_type=""):
        """Initialize PartitionedCall"""
        super(PartitionedCall, self).__init__(self.__class__.__name__)
        self.add_prim_attr("executor_type", executor_type)
        self.graph = graph

    def infer_shape(self, *inputs):
        """Shape inference is not provided in Python for this primitive.

        Raises:
            NotImplementedError: Always. Previously the exception class was returned as if it were a shape.
        """
        raise NotImplementedError

    def infer_dtype(self, *inputs):
        """Dtype inference is not provided in Python for this primitive.

        Raises:
            NotImplementedError: Always. Previously the exception class was returned as if it were a dtype.
        """
        raise NotImplementedError
1507
+
1508
+
1509
class CellBackwardHook(PrimitiveWithInfer):
    r"""
    This operator is used to hook input gradient and output gradient of Cell object.

    Note:
        This operator is only used in backward hook function of Cell object in pynative mode.

    Args:
        cell_id (str): Used to identify which cell obj the hook function registered on. For example, 'nn.Add()' is a
            cell object.
        cell: The object passed as the first argument to every hook. Default: ``None``.
        hook_dict: A weak-referenceable mapping whose values are the hook callables. Only a weak reference
            to it is kept. Default: ``None``, which means no hooks are run.

    Inputs:
        - **input** - The variable to hook.

    Outputs:
        - **output** - Returns `input` directly. `CellBackwardHook` does not affect the forward result.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import Tensor
        >>> from mindspore.ops import GradOperation
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> ms.set_context(mode=ms.PYNATIVE_MODE)
        >>> def hook_fn(grad):
        ...     print(grad)
        ...
        >>> hook = inner.CellBackwardHook()
        >>> hook_fn_key = hook.register_backward_hook()
        >>> def hook_test(x, y):
        ...     z = x * y
        ...     z = hook(z)
        ...     z = z * y
        ...     return z
        ...
        >>> grad_all = GradOperation(get_all=True)
        >>> def backward(x, y):
        ...     return grad_all(hook_test)(x, y)
        ...
        >>> output = backward(Tensor(1, mindspore.float32), Tensor(2, mindspore.float32))
        (Tensor(shape=[], dtype=Float32, value= 2),)
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
        >>> hook.remove_backward_hook(hook_fn_key)
        >>> output = backward(Tensor(1, mindspore.float32), Tensor(2, mindspore.float32))
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
    """
    # NOTE(review): the example above looks stale - `register_backward_hook` here returns None and this
    # class defines no `remove_backward_hook`; confirm against the base class before relying on it.

    def __init__(self, cell_id="", cell=None, hook_dict=None):
        """Initialize CellBackwardHook"""
        super(CellBackwardHook, self).__init__(self.__class__.__name__)
        self.cell_id = cell_id
        self.cell = cell
        # weakref.ref(None) raises TypeError, which made the default arguments unusable. For the
        # default, store a callable that behaves like a dead weak reference (always returns None).
        self.hook_dict = weakref.ref(hook_dict) if hook_dict is not None else (lambda: None)
        self.add_prim_attr("cell_id", cell_id)
        # Holds the gradient from the first of the two backward-hook invocations.
        self.grad_output = None

    def __call__(self, *args):
        # If args is empty, just return.
        if not args:
            return args
        return _run_op(self, self.name, args)

    def infer_shape(self, *inputs_shape):
        """Pass input shapes through; a single input is returned unwrapped."""
        if len(inputs_shape) == 1:
            return inputs_shape[0]
        return inputs_shape

    def infer_dtype(self, *inputs_type):
        """Pass input dtypes through; a single input is returned unwrapped."""
        if len(inputs_type) == 1:
            return inputs_type[0]
        return inputs_type

    def register_backward_hook(self):
        """
        Register the backward hook function.

        Args:
            None

        Returns:
            None

        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``
        """

        def hook_backward_grad(grad):
            if self.grad_output is None:
                self.grad_output = grad
                # Indicates the first time of call backward hook, and need to wait for the second time call
                return self.cell_id
            backward_hook_grad_input = grad
            # Dereference the weak reference once so the mapping cannot vanish between check and use.
            hooks = self.hook_dict()
            if hooks:
                for hook in hooks.values():
                    res = hook(self.cell, backward_hook_grad_input, self.grad_output)
                    if res is None:
                        continue
                    if not isinstance(res, tuple):
                        res = (res,)
                    if len(res) != len(grad):
                        raise TypeError(
                            "The backward hook return value size is {} not equal to expect grad input size {}".format(
                                len(res), len(grad)))
                    backward_hook_grad_input = res
            self.grad_output = None
            return backward_hook_grad_input

        self.set_hook_fn(hook_backward_grad, HookType.BackwardHook)

    def register_backward_pre_hook(self):
        """
        Register the backward pre hook function.

        Args:
            None

        Returns:
            None

        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``
        """

        def hook_backward_pre_grad(grad):
            backward_pre_hook_grad = grad
            # Dereference the weak reference once so the mapping cannot vanish between check and use.
            pre_hooks = self.hook_dict()
            if pre_hooks:
                for hook in pre_hooks.values():
                    res = hook(self.cell, backward_pre_hook_grad)
                    if res is None:
                        continue
                    if not isinstance(res, tuple):
                        res = (res,)
                    if len(res) != len(grad):
                        raise TypeError(
                            "The backward pre hook return value size is {} not equal to expect output size {}".format(
                                len(res), len(grad)))
                    backward_pre_hook_grad = res
            return backward_pre_hook_grad

        self.set_hook_fn(hook_backward_pre_grad, HookType.BackwardPreHook)
1655
+
1656
+
1657
class Format(PrimitiveWithInfer):
    r"""
    This operator is used to format a string.

    Note:
        Current not supported to using by customer.
        Only support convert str.format() in user code and it will be converted to be Format
        operation by ME-Compiler automatically.


    Inputs:
        - **input** -
          string : the string to be formatted.
          args : the format args.

    Outputs:
        - **output** - Returns formatted string.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Format"""
        self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])

    def __infer__(self, str_, *var):
        """
        Fold the format call when every input is constant.

        If the format string or any argument contains a variable, the result value is left as None.
        Otherwise keyword arguments are merged into the keyword mapping, all other arguments are
        passed positionally, and ``str.format`` is evaluated.
        """
        has_variable = _check_contains_variable(str_['dtype'], str_['value']) or any(
            _check_contains_variable(item['dtype'], item['value']) for item in var)
        if has_variable:
            return {'dtype': mstype.string, 'shape': [], 'value': None}

        kwargs = {}
        var_value = []
        for item in var:
            if isinstance(item["dtype"], typing.Keyword):
                kwargs.update(item["value"])
            else:
                # Keyword arguments used to be appended here as well, which let a format string
                # with a missing positional argument silently pick up the keyword dict instead.
                var_value.append(item["value"])

        value = str_['value'].format(*var_value, **kwargs)
        return {'dtype': mstype.string, 'shape': [], 'value': value}
1707
+
1708
+
1709
class FlattenConcat(Primitive):
    """
    Flatten input tensors and concatenate them into several chunk tensors grouped by data types.

    Args:
        fusion_size (int): Maximum memory chunk size in bytes, 0 for unlimited. Default: 0.

    Inputs:
        - **tensors** (tuple[Tensor], list[Tensor]) - The input Tensors to be flattened and concatenated.

    Outputs:
        tuple[Tensor], result chunk tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> t1 = Tensor(np.array([1]).astype(np.float32))
        >>> t2 = Tensor(np.array([2]).astype(np.float32))
        >>> t3 = Tensor(np.array([3]).astype(np.float64))
        >>> t4 = Tensor(np.array([4]).astype(np.float32))
        >>> t5 = Tensor(np.array([5]).astype(np.float64))
        >>> chunks = inner.FlattenConcat()([t1, t2, t3, t4, t5])
        >>> print(chunks[0].asnumpy())
        [1. 2. 4.]
        >>> print(chunks[1].asnumpy())
        [3. 5.]
    """

    @prim_attr_register
    def __init__(self, fusion_size=0):
        """Initialize FlattenConcat"""
        # Negative sizes are rejected; the accepted value is kept on the instance and as a primitive attribute.
        validator.check_non_negative_int(fusion_size, 'fusion_size', self.name)
        self.fusion_size = fusion_size
        self.add_prim_attr('fusion_size', fusion_size)
1745
+
1746
+
1747
class KMeansCentroids(PrimitiveWithInfer):
    """
    Calculate the segment_sum, segment_count, kmean_total_sum that are clustering results

    Args:
        use_actual_distance (bool): A bool value to decide whether do complete calculation of distance.

    Inputs:
        - **x** (Tensor(float32)) - Input data used for clustering
        - **y** (Tensor(float32)) - Initial centroids of clustering
        - **sum_square_y** (Tensor(float32)) - The result of preprocessing such as square, reduce and transpose of y
        - **sum_square_x** (Tensor(float32)) - The result of preprocessing such as square and reduce of x

    Outputs:
        - **segment_sum** (Tensor(float32)) - Clustering result w.r.t. each centroid
        - **segment_count** (Tensor(float32)) - Clustering count w.r.t. each centroid
        - **kmean_total_sum** (Tensor(float32)) - The sum of the distances from all vectors to their nearest centroid

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.common.dtype as mstype
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> from mindspore.ops import operations as P
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")

        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.reduce_sum = P.ReduceSum(keep_dims=True)
        >>>         self.square = P.Square()
        >>>         self.transpose = P.Transpose()
        >>>         self.k_means_centroids = P.KMeansCentroids(True)

        >>>     def construct(self, x, y):
        >>>         p1 = self.reduce_sum(self.square(x), -1)
        >>>         p2 = self.transpose(self.reduce_sum(self.square(y), -1), (1, 0))
        >>>         return self.k_means_centroids(x, y, p2, p1)

        >>> def test_net():
        >>>     data_type = np.float32
        >>>     x = Tensor(np.random.uniform(-10, 10, (65536, 128)).astype(data_type))
        >>>     y = P.Ones()((1048576, 128), mstype.float32)
        >>>     net = Net()
        >>>     local_sum, local_count, local_avg_distance = net(x, y)
    """

    @prim_attr_register
    def __init__(self, use_actual_distance):
        """Initialize KMeansCentroids"""
        validator.check_value_type('use_actual_distance', use_actual_distance, [bool], self.name)
        self.init_prim_io_names(inputs=['x', 'y', 'sum_square_y', 'sum_square_x'],
                                outputs=['segment_sum', 'segment_count', 'kmean_total_sum'])

    def infer_shape(self, x_shape, y_shape, sum_square_y_shape, sum_square_x_shape):
        """
        Infer shape of primitive.

        All four inputs must be 2-D. With ``k = y_shape[0]`` and ``em_size = x_shape[1]`` the output
        shapes are ``(k, em_size)``, ``(k, 1)`` and ``(1,)``.
        """
        expected_shape_size = 2
        validator.check_int(len(x_shape), expected_shape_size, validator.EQ, "dims of x", self.name)
        validator.check_int(len(y_shape), expected_shape_size, validator.EQ, "dims of y", self.name)
        validator.check_int(len(sum_square_y_shape), expected_shape_size, validator.EQ,
                            "dims of sum_square_y", self.name)
        validator.check_int(len(sum_square_x_shape), expected_shape_size, validator.EQ,
                            "dims of sum_square_x", self.name)

        validator.check_int(x_shape[1], y_shape[1], validator.EQ,
                            "the second dim of x and the second dim of y", self.name)
        validator.check_int(y_shape[0], sum_square_y_shape[1], validator.EQ,
                            "the first dim of y and the second dim of sum_square_y", self.name)
        validator.check_int(x_shape[0], sum_square_x_shape[0], validator.EQ,
                            "the first dim of x and the first dim of sum_square_x", self.name)
        # The compared value is sum_square_x_shape[1], so the message names the second dim.
        validator.check_int(sum_square_y_shape[0], sum_square_x_shape[1], validator.EQ,
                            "the first dim of sum_square_y and the second dim of sum_square_x",
                            self.name)
        validator.check_int(sum_square_y_shape[0], 1, validator.EQ,
                            "the first dim of sum_square_y", self.name)

        k = y_shape[0]
        em_size = x_shape[1]
        # `(1)` is just the int 1; the third output shape needs the one-element tuple `(1,)`.
        return (k, em_size), (k, 1), (1,)
1829
+
1830
+
1831
class ClipByNorm(PrimitiveWithInfer):
    r"""
    Clips tensor values to a maximum :math:`L_2`-norm.

    Note:
        The output tensor of this operator remains the same with input tensor if the :math:`L_2`-norm of the input
        tensor is not greater than the argument `clip_norm`. Otherwise the output tensor will be normalized as:

        .. math::
            \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},

        where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.

    Args:
        axis (Union[None, int, tuple(int), list(int)]): Compute the `L_2`-norm along the specific dimension.
            Default: ``None``, all dimensions to calculate.

    Inputs:
        - **x** (Tensor) - Tensor of shape N-D. The type must be float16 or float32.
        - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`.
          Or a Tensor which shape can be broadcast to the shape of `x`. The type must be float16 or float32.

    Outputs:
        Tensor, clipped Tensor with the same shape as the `x`, whose type is float32.

    Raises:
        TypeError: If `axis` is not one of None, int, tuple(int) and list(int).
        TypeError: If dtype of `x` is neither float16 nor float32.
        TypeError: If dtype of `clip_norm` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> clip_by_norm = inner.ClipByNorm()
        >>> x = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
        >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
        >>> output = clip_by_norm(x, clip_norm)
        >>> print(output.shape)
        (4, 16)
    """

    @prim_attr_register
    def __init__(self, axis=None):
        """Initialize ClipByNorm"""
        # None is stored as an empty tuple.
        self.axis = () if axis is None else axis
        validator.check_value_type('axis', self.axis, [int, tuple, list], self.name)
        # A bare int is wrapped so every entry can be type-checked in one loop.
        if isinstance(self.axis, Iterable):
            axis_check = self.axis
        else:
            axis_check = (self.axis,)
        for i, value in enumerate(axis_check):
            validator.check_value_type('axis[%d]' % i, value, [int], self.name)
        self.init_attrs['axis'] = self.axis
        self.add_prim_attr('axis', self.axis)
        self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])

    def infer_shape(self, x_shape, clip_norm_shape):
        """Infer shape for ClipByNorm"""
        x_dim = len(x_shape)
        if isinstance(self.axis, Iterable):
            axes = self.axis
        else:
            axes = (self.axis,)
        # Each axis must fall in [-x_dim, x_dim); the output shape equals the shape of x.
        for value in axes:
            validator.check_int_range(value, -x_dim, x_dim, validator.INC_LEFT, 'axis', self.name)
        return x_shape

    def infer_dtype(self, x_type, clip_norm_type):
        """Infer data type for ClipByNorm"""
        allowed_dtypes = [mstype.float16, mstype.float32]
        validator.check_tensor_dtype_valid("x_type", x_type, allowed_dtypes, self.name)
        validator.check_tensor_dtype_valid("clip_norm_type", clip_norm_type,
                                           [mstype.float16, mstype.float32], self.name)
        return mstype.float32
1903
+
1904
+
1905
class TopTypeof(Primitive):
    """
    Internal primitive method, to speed up mindspore.ops.typeof.

    Returns the top type of the input data.

    In Pynative mode, returns the top type in cache.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Build the wrapped primitive and the cache keyed by Python type name."""
        self.prim = Primitive('TopTypeof')
        self.typeof_cache = {
            'slice': mstype.Slice(),
            'list': mstype.List(),
            'tuple': mstype.Tuple(),
            'Tensor': mstype.tensor_type,
            'NoneType': mstype.NoneType(),
            'int': mstype.Int(),
            'bool': mstype.Bool(),
            'ellipsis': mstype.Ellipsis_(),
            'dict': mstype.Dict()
        }

    def __call__(self, x):
        """Return the cached type for `x`, or fall back to constant folding when it is not cached."""
        type_name = type(x).__name__
        # Any type whose name contains 'Tensor' shares the single 'Tensor' cache entry.
        cache_key = 'Tensor' if 'Tensor' in type_name else type_name
        if cache_key not in self.typeof_cache:
            return _pynative_executor.constant_folding(self.prim, x)
        return self.typeof_cache.get(cache_key)
1939
+
1940
+
1941
class MixedPrecisionCast(Primitive):
    r"""
    Internal primitive method, to achieve mindspore.functional.mixed_precision_cast.

    Note:
        This internal primitive method used to do mixed precision conversion.
        Only the input object with float dtype will be cast.

    Inputs:
        - **dtype** (Union[Float16, Float32]) - The data type of the output object.
        - **input** (Union[Tensor, Tuple, Dictionary, KeywordArg]) - The object to be cast.

    Outputs:
        Object, its dtype is the same as `dtype` and shape is the same as 'input'.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore import dtype as mstype
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> x = Tensor(np.ones([2, 3], dtype=np.float32))
        >>> out = inner.MixedPrecisionCast()(mstype.float16, x)
        >>> print(out.dtype)
        Float16
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MixedPrecisionCast"""
        self.init_prim_io_names(inputs=['dst_dtype', 'input_x'], outputs=['output'])
        self.cast = Cast()
        self.hyper_map = C.HyperMap()

    def __call__(self, dst_dtype, x):
        """Apply the cast through the hyper map; only float16/float32/float64/bfloat16 Tensors are cast."""

        def cast_inner(data):
            # Anything that is not a Tensor is passed through untouched.
            if not isinstance(data, Tensor):
                return data
            if data.dtype not in (mstype.float16, mstype.float32,
                                  mstype.float64, mstype.bfloat16):
                return data
            return self.cast(data, dst_dtype)

        return self.hyper_map(cast_inner, x)
1985
+
1986
+
1987
+ class CheckBprop(PrimitiveWithInfer):
1988
+ """
1989
+ Checks whether the data type and the shape of corresponding elements from tuples x and y are the same.
1990
+
1991
+ Args:
1992
+ prim_to_check (str): The name of the primitive being checked. Default: ''.
1993
+
1994
+ Inputs:
1995
+ - **input_x** (tuple[Tensor]) - The `input_x` contains the outputs of bprop to be checked.
1996
+ - **input_y** (tuple[Tensor]) - The `input_y` contains the inputs of bprop to check against.
1997
+
1998
+ Outputs:
1999
+ Tuple[Tensor], the `input_x`,
2000
+ if data type and shape of corresponding elements from `input_x` and `input_y` are the same.
2001
+
2002
+ Raises:
2003
+ TypeError: If `input_x` or `input_y` is not a Tensor.
2004
+
2005
+ Supported Platforms:
2006
+ ``Ascend`` ``GPU`` ``CPU``
2007
+
2008
+ Examples:
2009
+ >>> class Net(nn.Cell):
2010
+ ... def __init__(self):
2011
+ ... super(Net, self).__init__()
2012
+ ... self.op = ops.CheckBprop()
2013
+ ... def construct(self, x, y):
2014
+ ... return self.op(x, y)
2015
+ ...
2016
+ >>> net = Net()
2017
+ >>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
2018
+ >>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
2019
+ >>> output = net(input_x, input_y)
2020
+ >>> print(output)
2021
+ (Tensor(shape=[2, 2], dtype=Float32, value=
2022
+ [[ 2.00000000e+00, 2.00000000e+00],
2023
+ [ 2.00000000e+00, 2.00000000e+00]]),)
2024
+ """
2025
+
2026
+ @prim_attr_register
2027
def __init__(self, prim_to_check=""):
    """Initialize CheckBprop and keep the given `prim_to_check` name on the instance."""
    self.prim_to_check = prim_to_check
2030
+
2031
def infer_shape(self, xshapes, yshapes):
    """
    Infer shape: check the gradient shapes against the argument shapes and return `xshapes`.

    Pairs in which either shape is empty are skipped, and a dimension of -1 on either side
    matches any value.

    Raises:
        ValueError: If the two tuples differ in length, or a compared pair of shapes does not match.
    """
    tips = "user defined method 'bprop'"
    validator.check_value_type('grads', xshapes, (tuple,), tips)
    validator.check_value_type('params', yshapes, (tuple,), tips)
    if len(xshapes) != len(yshapes):
        raise ValueError(f"For {tips} the number of return values(gradients) must be equal to "
                         f"the number of input arguments except 'out' and 'dout', "
                         f"which is:{len(yshapes)} but got {len(xshapes)}.")

    def shape_equal(shape1, shape2):
        if len(shape1) != len(shape2):
            return False
        return all(dim1 == -1 or dim2 == -1 or dim1 == dim2
                   for dim1, dim2 in zip(shape1, shape2))

    for i, (xshape, yshape) in enumerate(zip(xshapes, yshapes)):
        if xshape and yshape and not shape_equal(xshape, yshape):
            raise ValueError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
                             f"should have the same shape as the {i}th argument, "
                             f"which is:{yshape}, but got: {xshape}.")
    return xshapes
2060
+
2061
def infer_dtype(self, xdtypes, ydtypes):
    """Validate the dtypes returned by a bprop against the argument dtypes.

    Args:
        xdtypes (tuple): Dtypes of the values (gradients) returned by the bprop.
        ydtypes (tuple): Dtypes of the arguments the gradients correspond to.

    Returns:
        tuple, `xdtypes` unchanged once every check has passed.

    Raises:
        ValueError: If the two tuples differ in length.
        TypeError: If an argument dtype is a `FunctionType` while the matching
            gradient dtype is not an `EnvType`, or if a gradient dtype differs
            from the dtype of its argument.
    """
    tips = "user defined method 'bprop'"
    validator.check_value_type('grads', xdtypes, (tuple,), tips)
    validator.check_value_type('params', ydtypes, (tuple,), tips)
    if len(xdtypes) != len(ydtypes):
        raise ValueError(f"For {tips}, the number of return values(gradients) must be equal to "
                         "the number of input arguments except 'out' and 'dout', "
                         f"which is:{len(ydtypes)} but got {len(xdtypes)}.")

    for position, (grad_dtype, param_dtype) in enumerate(zip(xdtypes, ydtypes)):
        # A pair is skipped entirely when either side is an AnythingType.
        if isinstance(grad_dtype, mstype.AnythingType) or isinstance(param_dtype, mstype.AnythingType):
            continue
        if isinstance(param_dtype, mstype.FunctionType) and not isinstance(grad_dtype, mstype.EnvType):
            raise TypeError(f"For {tips}, the {position}th return value(gradient of the {position}th argument) type "
                            f"should be {mstype.EnvType}, but got {grad_dtype}.")
        # NOTE(review): this equality check also runs for the FunctionType/EnvType
        # pair accepted just above - confirm that is intended.
        if grad_dtype != param_dtype:
            raise TypeError(f"For {tips}, the {position}th return value(gradient of the {position}th argument) "
                            f"should have the same dtype as the {position}th argument, "
                            f"which is:{param_dtype}, but got: {grad_dtype}.")
    return xdtypes
2085
+
2086
+
2087
# Shared module-level CheckBprop instance, built with the default `prim_to_check`.
check_bprop = CheckBprop()
2088
+
2089
+
2090
class SameTypeShape(PrimitiveWithInfer):
    """
    Verifies that two tensors have the same data type and the same shape.

    Refer to :func:`mindspore.ops.same_type_shape` for more detail.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> input_y = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> output = ops.SameTypeShape()(input_x, input_y)
        >>> print(output)
        [[2. 2.]
         [2. 2.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SameTypeShape; the primitive has no attributes."""

    def __call__(self, x, y):
        """Run the check in PyNative mode and return `x` unchanged."""
        for label, candidate in (('x', x), ('y', y)):
            validator.check_value_type(label, candidate, Tensor, self.name)
        # A dtype mismatch is reported as TypeError; the shape check uses the validator's default.
        validator.check('x dtype', x.dtype, 'y dtype', y.dtype, validator.EQ, self.name, TypeError)
        validator.check('x shape', x.shape, 'y shape', y.shape, validator.EQ, self.name)
        return x

    def __infer__(self, x, y):
        """Run the same checks on the abstract inputs and return `x`'s abstract value."""
        for label, abstract in (('x', x), ('y', y)):
            validator.check_subclass(label, abstract['dtype'], mstype.tensor_type, self.name)
        validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], validator.EQ, self.name, TypeError)
        validator.check('x shape', x['shape'], 'y shape', y['shape'], validator.EQ, self.name)
        return x
2126
+
2127
+
2128
# Shared module-level SameTypeShape instance.
same_type_shape_ = SameTypeShape()
2129
+
2130
+
2131
def _is_subclass_(type_, dtype):
    """Return `typing.is_subclass(type_, dtype)`, or False when `type_` is not a `typing.Type`."""
    return typing.is_subclass(type_, dtype) if isinstance(type_, typing.Type) else False
2135
+
2136
+
2137
class IsSubClass(PrimitiveWithInfer):
    """
    Checks whether one type is a sub-class of another type.

    Inputs:
        - **sub_type** (mindspore.dtype) - The type to be checked. Only constant value is allowed.
        - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.

    Outputs:
        bool, the check result.

    Raises:
        TypeError: If `sub_type` or `type_` is not a Type.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> output = ops.IsSubClass()(mindspore.int32, mindspore.intc)
        >>> print(output)
        True
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IsSubClass; the primitive has no attributes."""

    def __infer__(self, sub_type, type_):
        """Evaluate the sub-class relation on the constant values of both inputs."""
        candidate = sub_type['value']
        target = type_['value']

        # Both constants must be Type objects before they are compared.
        for label, constant in (("sub_type", candidate), ("type_", target)):
            validator.check_value_type(label, constant, [mstype.Type], self.name)

        return {'shape': (),
                'dtype': mstype.type_type,
                'value': _is_subclass_(candidate, target)}
2177
+
2178
+
2179
# Shared module-level IsSubClass instance.
issubclass_ = IsSubClass()
2180
+
2181
+
2182
class IsInstance(PrimitiveWithInfer):
    """
    Checks whether an object is an instance of a target type.

    Inputs:
        - **inst** (Any Object) - The instance to be checked. Only constant value is allowed.
        - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.

    Outputs:
        bool, the check result.

    Raises:
        TypeError: If `type_` is not a Type.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> inst = 1
        >>> output = ops.IsInstance()(inst, mindspore.int32)
        >>> print(output)
        False
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IsInstance; the primitive has no attributes."""

    def __infer__(self, inst, type_):
        """Compare the dtype of `inst` with the constant target type."""
        inst_dtype = inst['dtype']
        target = type_['value']

        validator.check_value_type("type_", target, [mstype.Type], self.name)

        # list_ and tuple_ targets are answered with a plain isinstance on the
        # dtype object; every other target goes through the sub-class helper.
        for ms_container, py_container in ((mstype.list_, list), (mstype.tuple_, tuple)):
            if target == ms_container:
                matched = isinstance(inst_dtype, py_container)
                break
        else:
            matched = _is_subclass_(inst_dtype, target)

        return {'shape': (),
                'dtype': mstype.type_type,
                'value': matched}
2227
+
2228
+
2229
class ConvertToAdapterTensor(Primitive):
    """
    Converts a MindSpore Tensor into MSAdapter's Tensor type,
    where MSAdapter's Tensor is a subclass of MindSpore's Tensor.

    Inputs:
        - **x** (Tensor) - The input tensor.

    Outputs:
        A tensor, whose type is MSAdapter's Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor([1, 2 ,3])
        >>> x = ops.ConvertToAdapterTensor()(x)
        >>> print(x)
        [1 2 3]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ConvertToAdapterTensor; the primitive has no attributes."""

    def __call__(self, x):
        """PyNative path: delegate to `ms_adapter_registry.tensor` with `cast_tensor=True`."""
        adapter_tensor = ms_adapter_registry.tensor(x, cast_tensor=True)
        return adapter_tensor
2257
+
2258
+
2259
# Shared module-level ConvertToAdapterTensor instance.
convert_to_adapter_tensor = ConvertToAdapterTensor()
2260
+
2261
+
2262
class ConvertToMsTensor(Primitive):
    """
    Converts an MSAdapter Tensor into MindSpore's Tensor type,
    where MSAdapter's Tensor is a subclass of MindSpore's Tensor.

    Inputs:
        - **x** (Tensor) - The input tensor.

    Outputs:
        A tensor, whose type is MindSpore's Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor([1, 2 ,3])
        >>> x = ops.ConvertToMsTensor()(x)
        >>> print(x)
        [1 2 3]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ConvertToMsTensor; the primitive has no attributes."""

    def __call__(self, x):
        """PyNative path: rebuild a StubTensor from its parts, otherwise deep-copy `x`."""
        if not isinstance(x, StubTensor):
            return ops.auto_generate.deepcopy(x)
        # A StubTensor is re-wrapped around the same stub and tensor objects.
        return StubTensor(stub=x.stub, tensor=x.tensor)
2292
+
2293
+
2294
# Shared module-level ConvertToMsTensor instance.
convert_to_ms_tensor = ConvertToMsTensor()
2295
+
2296
+
2297
class GetGrad(Primitive):
    """
    Uses a position id or a Parameter object to pick one gradient out of the output
    returned by :func:`mindspore.ops.grad`.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize GetGrad and register its input and output names."""
        self.init_prim_io_names(inputs=['gradients', 'x'], outputs=['gradient'])

    def __call__(self, gradients, x):
        """Search `gradients` for the entry identified by `x`.

        Args:
            gradients: Nested tuples that are searched for 2-element tuples
                of the form `(identifier, gradient)`.
            x (Union[int, Parameter]): Position id, or the Parameter whose
                `name` is used as the identifier.

        Returns:
            The second element of the matching 2-element tuple. When several
            tuples match, the one visited last is returned.

        Raises:
            TypeError: If `x` is neither an int nor a Parameter.
            RuntimeError: If no match is found, or the selected gradient is None.
        """
        if not isinstance(x, (int, Parameter)):
            raise TypeError(
                f"For `get_grad`, the `x` should be an integer or a Parameter, but got {x}")
        hash_id = x.name if isinstance(x, Parameter) else x
        matches = []

        def _collect(node, identifier):
            # Only tuples are inspected; any other node ends this branch.
            if not isinstance(node, tuple):
                return
            if len(node) != 2 or identifier != node[0]:
                for child in node:
                    _collect(child, identifier)
                return
            # A 2-tuple whose first element equals the identifier is a hit;
            # its gradient is recorded and the pair is not descended into.
            matches.append(node[1])

        _collect(gradients, hash_id)
        output = matches[-1] if matches else None
        if output is None:
            raise RuntimeError(
                f"Can not find the gradient for position or Parameter {x}")
        return output
2333
+
2334
+
2335
class IsParameter(PrimitiveWithInfer):
    """
    Checks whether the input is a `Parameter`.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IsParameter; the primitive has no attributes."""

    def __call__(self, x):
        """Direct call: report whether `x` is a Parameter instance."""
        return isinstance(x, Parameter)

    def __infer__(self, x):
        """Inference: the answer is whether the abstract dtype of `x` is a `RefType`."""
        is_ref = isinstance(x['dtype'], mstype.RefType)
        return {'shape': [], 'dtype': mstype.bool_, 'value': is_ref}
2351
+
2352
+
2353
class TileSize(Primitive):
    r"""
    Tile size for matmul
    """

    @prim_attr_register
    def __init__(self):
        """Initialize TileSize and register its input and output names."""
        self.init_prim_io_names(inputs=['shape', 'out_shape', 'ndim'], outputs=['output'])

    def __call__(self, shape, out_shape, ndim):
        """Build a tuple of `ndim` entries that default to 1.

        The entry at each position where `shape` and `out_shape` differ is
        replaced by the `out_shape` value at that position.
        """
        multiples = [1] * ndim
        for axis, (current, wanted) in enumerate(zip(shape, out_shape)):
            if current != wanted:
                multiples[axis] = wanted
        return tuple(multiples)
2369
+
2370
+
2371
class GetitemTensorIndexInfo(Primitive):
    r"""
    Get getitem tensor index info
    """

    @prim_attr_register
    def __init__(self, is_ascend):
        """Initialize GetitemTensorIndexInfo.

        Args:
            is_ascend (bool): Flag forwarded to `Tensor_.getitem_index_info` on every call.
        """
        output_names = ["new_index", "tensor_update_types", "tensor_update_args"]
        self.init_prim_io_names(inputs=['data', 'index'], outputs=output_names)
        validator.check_value_type('is_ascend', is_ascend, [bool], self.name)
        self.is_ascend = is_ascend

    def __call__(self, data, index):
        """Delegate to `Tensor_.getitem_index_info` with the stored `is_ascend` flag."""
        return Tensor_.getitem_index_info(data, index, self.is_ascend)
2386
+
2387
+
2388
class SetitemTensorIndexInfo(Primitive):
    r"""
    Get setitem tensor index info
    """

    @prim_attr_register
    def __init__(self, is_ascend):
        """Initialize SetitemTensorIndexInfo.

        Args:
            is_ascend (bool): Flag forwarded to `Tensor_.setitem_index_info` on every call.
        """
        output_names = ['new_index',
                        'v_transfer_types',
                        'v_transfer_args',
                        'tensor_update_types',
                        'tensor_update_args']
        self.init_prim_io_names(inputs=['data', 'index', 'value'], outputs=output_names)
        validator.check_value_type('is_ascend', is_ascend, [bool], self.name)
        self.is_ascend = is_ascend

    def __call__(self, data, index, value):
        """Delegate to `Tensor_.setitem_index_info` with the stored `is_ascend` flag."""
        return Tensor_.setitem_index_info(data, index, value, self.is_ascend)
2407
+
2408
+
2409
class IsConstant(Primitive):
    r"""
    Check if the input is constant
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IsConstant; the primitive has no attributes."""

    def __call__(self, x):
        """Direct call: always returns True, whatever `x` is."""
        return True
2420
+
2421
+
2422
class SelectView(Primitive):
    r"""
    Select tensor of view
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SelectView and register its input and output names."""
        input_names = ['input_tensor', 'input_indices', 'axis']
        self.init_prim_io_names(inputs=input_names, outputs=['output'])
2430
+
2431
+
2432
class CopyWithSlice(Primitive):
    r"""
    Copy data to discontinuous tensor
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CopyWithSlice: set the `side_effect_mem` attribute, then register io names."""
        self.add_prim_attr('side_effect_mem', True)
        # The single output is registered under the same name as the first input, 'x'.
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])
2441
+
2442
+
2443
class FFN(Primitive):
    r"""
    The FFN computation is similar to a Feed-Forward Network: it contains matmul + gelu + matmul.

    Args:
        activation (str): The activation type, set to 'fastgelu' or 'gelu'.
            Only 'fastgelu' is supported for now.
        inner_precise (int): The precision mode, set to 0 for high precision or 1 for high performance.
            Only 1 is supported for now.

    Inputs:
        - **x** (Tensor) - The input tensor with data type of int8, float16.
          Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
        - **weight1** (Tensor) - The weight1 tensor with data type of float16.
          Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
        - **weight2** (Tensor) - The weight2 tensor with data type of float16.
          Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
        - **expert_tokens** (Tensor) - The expert tokens tensor with data type of int64.
          Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, .., 9)`
          indicates that the 0th expert deals with 2 tokens, the 1st expert deals with 1 token,
          the 2nd expert does nothing and so on.
        - **bias1** (Tensor) - The bias1 tensor with data type of float16.
          Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
        - **bias2** (Tensor) - The bias2 tensor with data type of float16.
          Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
        - **scale** (Tensor) - The scale tensor with data type of float16. Not enabled now.
        - **offset** (Tensor) - The offset tensor with data type of float16. Not enabled now.
        - **deq_scale1** (Tensor) - The deq_scale1 tensor with data type of float16. Not enabled now.
        - **deq_scale2** (Tensor) - The deq_scale2 tensor with data type of float16. Not enabled now.

        The input names `antiquant_scale1`, `antiquant_scale2`, `antiquant_offset1` and
        `antiquant_offset2` are registered as well; they are not described here.

    Outputs:
        Tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`. With data type of float16.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore.ops.operations import _inner_ops
        >>> b = 4
        >>> s = 128
        >>> h = 1024
        >>> h_f = 4 * h
        >>> e = 16
        >>> x = Tensor(np.random.randn(s, h).astype(np.float16))
        >>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
        >>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
        >>> expert_tokens = Tensor(np.full(e, 8))
        >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
        >>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
        >>> ffn = _inner_ops.FFN("fastgelu", 1)
        >>> output = ffn(x, w1, w2, expert_tokens, bias1, bias2)
        >>> print(output)
    """

    @prim_attr_register
    def __init__(self, activation, inner_precise):
        """Initialize FFN: register io names, then type-check both arguments."""
        input_names = [
            "x", "weight1", "weight2", "expert_tokens", "bias1",
            "bias2", "scale", "offset", "deq_scale1", "deq_scale2",
            "antiquant_scale1", "antiquant_scale2",
            "antiquant_offset1", "antiquant_offset2",
        ]
        self.init_prim_io_names(inputs=input_names, outputs=["y"])
        cls_name = self.name
        # Only the argument types are validated here, not their values.
        for arg_name, arg_value, allowed in (("activation", activation, [str]),
                                             ("inner_precise", inner_precise, [int])):
            validator.check_value_type(arg_name, arg_value, allowed, cls_name)
2505
+ cls_name = self.name
2506
+ validator.check_value_type("activation", activation, [str], cls_name)
2507
+ validator.check_value_type("inner_precise", inner_precise, [int], cls_name)
2508
+
2509
+
2510
class _VirtualConverterEnd(PrimitiveWithInfer):
    """
    Auto parallel virtual operator.
    """

    @prim_attr_register
    def __init__(self, input_nums):
        """Initialize _VirtualConverterEnd and keep `input_nums` for shape inference."""
        self.input_nums = input_nums

    def infer_shape(self, *args):
        """Return the first input shape with its leading axis multiplied by `input_nums`."""
        first_shape = args[0]
        leading = first_shape[0] * self.input_nums
        return (leading, *first_shape[1:])

    def infer_dtype(self, *args):
        """Return the dtype of the first input."""
        first_dtype = args[0]
        return first_dtype
2525
+
2526
+
2527
class _VirtualConverterBegin(PrimitiveWithInfer):
    """
    Auto parallel virtual operator.
    """

    @prim_attr_register
    def __init__(self, output_nums):
        """Initialize _VirtualConverterBegin.

        Args:
            output_nums: Number of outputs. It divides the leading axis in
                `infer_shape` and sets how many times the inferred shape and
                dtype are repeated.
        """
        self.output_nums = output_nums

    def infer_shape(self, arg):
        """Infer `output_nums` identical shapes with the leading axis divided by `output_nums`.

        Args:
            arg: Shape of the single input.

        Returns:
            tuple, the per-output shape repeated `output_nums` times.

        Raises:
            ValueError: If `output_nums` is zero.
        """
        if self.output_nums == 0:
            # The error must be raised; returning the exception object would
            # hand it back to the caller as if it were the inferred shape.
            raise ValueError("output_nums can\'t be zero.")
        # NOTE(review): true division gives a float leading dimension even when
        # the split is exact - confirm whether floor division is intended.
        new_arg = (arg[0] / self.output_nums,) + tuple(arg[1:])
        return (new_arg,) * self.output_nums

    def infer_dtype(self, arg):
        """Return the input dtype repeated `output_nums` times."""
        return (arg,) * self.output_nums