mindspore 2.4.0__cp311-cp311-macosx_10_15_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1387) hide show
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,823 @@
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ Transformer Cells module, include TransformerEncoderLayer, TransformerDecoderLayer,
17
+ TransformerEncoder, TransformerDecoder, Transformer.
18
+ """
19
+ import copy
20
+ import math
21
+ from typing import Union, Optional
22
+ import mindspore
23
+ import mindspore.ops as ops
24
+ import mindspore.common.dtype as mstype
25
+ from mindspore.common.tensor import Tensor
26
+ from mindspore.common.parameter import Parameter
27
+ from mindspore.common.initializer import initializer, XavierNormal, XavierUniform, \
28
+ HeUniform, Uniform, _calculate_fan_in_and_fan_out
29
+ from mindspore.ops.function.nn_func import multi_head_attention_forward
30
+ from mindspore.nn.cell import Cell
31
+ from .basic import Dense, Dropout
32
+ from .activation import ReLU, GELU
33
+ from .normalization import LayerNorm
34
+ from .container import CellList
35
+ __all__ = ['MultiheadAttention', 'TransformerEncoderLayer', 'TransformerDecoderLayer',
36
+ 'TransformerEncoder', 'TransformerDecoder', 'Transformer']
37
+
38
+
39
class MultiheadAttention(Cell):
    r"""
    This is an implementation of multihead attention in the paper `Attention is all you need
    <https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector, the key vector and value vector,
    the attention will be performed as the following:

    .. math::
        MultiHeadAttention(query, key, value) = Concat(head_1, \dots, head_h)W^O

    where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`, and :math:`W^O` , :math:`W_i^Q` , :math:`W_i^K` ,
    :math:`W_i^V` are weight matrices. The default input / output projection layers is with a bias.

    if query, key and value tensor is same, then it will be self attention.

    Args:
        embed_dim (int): Total dimension of MultiheadAttention.
        num_heads (int): Number of attention heads. Note that `embed_dim` will be split
            across `num_heads` (i.e. each head will have dimension `embed_dim // num_heads`).
        dropout (float): Dropout probability of `attn_output_weights`. Default: ``0.0``.
        has_bias (bool): Whether adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv (bool): Whether adds bias to the key and value sequences at axis=0. Default: ``False``.
        add_zero_attn (bool): Whether adds a new batch of zeros to the key and value sequences at axis=1.
            Default: ``False``.
        kdim (int): Total number of features for keys. Default: ``None`` (`kdim=embed_dim`).
        vdim (int): Total number of features for values. Default: ``None`` (`vdim=embed_dim`).
        batch_first (bool): If ``True``, then the input and output shape are :math:`(batch, seq, feature)` ,
            else :math:`(seq, batch, feature)` . Default: ``False``.
        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .

    Inputs:
        - **query** (Tensor) - The query embeddings. If `query` is unbatched, the shape is :math:`(L, E_q)`,
          otherwise the shape is :math:`(L, N, E_q)` when `batch_first=False` or :math:`(N, L, E_q)` when
          `batch_first=True` , where :math:`L` is the target sequence length, :math:`N` is the batch size,
          and :math:`E_q` is the query embedding dimension `embed_dim`. Supported types: float16, float32,
          float64. Queries are compared against key-value pairs to produce the output.
        - **key** (Tensor) - The key embeddings. If `key` is unbatched, the shape is :math:`(S, E_k)`, otherwise
          the shape is :math:`(S, N, E_k)` when `batch_first=False` or :math:`(N, S, E_k)` when
          `batch_first=True` , where :math:`S` is the source sequence length, :math:`N` is the batch size,
          and :math:`E_k` is the key embedding dimension `kdim`. Supported types: float16, float32, float64.
        - **value** (Tensor) - The value embeddings. If `value` is unbatched, the shape is :math:`(S, E_v)`,
          otherwise the shape is :math:`(S, N, E_v)` when `batch_first=False` or :math:`(N, S, E_v)` when
          `batch_first=True` , where :math:`S` is the source sequence length, :math:`N` is the batch size,
          and :math:`E_v` is the value embedding dimension `vdim`. Supported types: float16, float32, float64.
        - **key_padding_mask** (Tensor, optional) - If specified, a mask of shape :math:`(N, S)` indicating which
          elements within `key` to ignore for the purpose of attention (i.e. treat as "padding").
          For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported.
          For a binary mask, a ``True`` value indicates that the corresponding `key` value will be ignored for
          the purpose of attention. For a float mask, it will be directly added to the corresponding `key` value.
          Supported float types: float16, float32, float64. Default: ``None``.
        - **need_weights** (bool) - Whether returns `attn_output_weights` in addition to `attn_outputs`.
          Default: ``True``.
        - **attn_mask** (Tensor, optional) - If specified, a 2D or 3D mask preventing attention to certain positions.
          Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num_heads}, L, S)`, where :math:`N` is the
          batch size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length.
          A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry
          in the batch. For a binary mask, a ``True`` value indicates that the corresponding position is not allowed
          to attend. For a float mask, the mask values will be added to the attention weight.
          Supported float types: float16, float32, float64. Default: ``None``.
        - **average_attn_weights** (bool) - If true, indicates that the returned `attn_weights` should be averaged
          across heads. Otherwise, `attn_weights` are provided separately per head. Note that this flag only
          has an effect when `need_weights=True`. Default: ``True`` (i.e. average weights across heads)

    Outputs:
        Tuple, a tuple contains(`attn_output`, `attn_output_weights`)

        - **attn_output** - Attention outputs. If input is unbatched, the output shape is :math:`(L, E)`, otherwise
          the output shape is :math:`(L, N, E)` when `batch_first=False` or :math:`(N, L, E)` when
          `batch_first=True` , where :math:`L` is the target sequence length, :math:`N` is the batch size,
          and :math:`E` is the embedding dimension `embed_dim`.
        - **attn_output_weights** - Only returned when `need_weights=True`. If `average_attn_weights=True`,
          returns attention weights averaged across heads with shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)` when input is batched, where :math:`N` is the batch size, :math:`L` is
          the target sequence length, and :math:`S` is the source sequence length.
          If `average_attn_weights=False`, returns attention weights per
          head of shape :math:`(\text{num_heads}, L, S)` when input is unbatched or
          :math:`(N, \text{num_heads}, L, S)` when input is batched.

    Raises:
        ValueError: If the init argument `embed_dim` is not divisible by `num_heads`.
        ValueError: If the input argument `key_padding_mask` is not bool or floating types.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> embed_dim, num_heads = 128, 8
        >>> seq_length, batch_size = 10, 8
        >>> query = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
        >>> key = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
        >>> value = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
        >>> multihead_attn = ms.nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
        >>> print(attn_output.shape)
        (10, 8, 128)
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, has_bias=True, add_bias_kv=False,
                 add_zero_attn=False, kdim=None, vdim=None, batch_first=False, dtype=mstype.float32):
        super().__init__()
        self.embed_dim = embed_dim
        # kdim / vdim fall back to embed_dim when not given.
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        # Selects between one packed input projection and three separate ones below.
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError("The init argument 'embed_dim' must be divisible by 'num_heads'.")

        if dtype is None:
            dtype = mindspore.float32
        if not self._qkv_same_embed_dim:
            # Key/value feature sizes differ from embed_dim: keep one weight per projection.
            self.q_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, embed_dim), dtype), 'q_proj_weight')
            self.k_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.kdim), dtype), 'k_proj_weight')
            self.v_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.vdim), dtype), 'v_proj_weight')
            self.in_proj_weight = None
        else:
            # All three projections share the input size: store them as one (3 * embed_dim, embed_dim) weight.
            self.in_proj_weight = Parameter(initializer(XavierUniform(), (3 * embed_dim, embed_dim), dtype),
                                            'in_proj_weight')
            self.q_proj_weight = None
            self.k_proj_weight = None
            self.v_proj_weight = None

        if has_bias:
            self.in_proj_bias = Parameter(initializer('zeros', (3 * embed_dim), dtype), 'in_proj_bias')
        else:
            self.in_proj_bias = None
        # The output projection bias is drawn from Uniform(bound) with bound = 1 / sqrt(fan_in).
        fan_in, _ = _calculate_fan_in_and_fan_out((embed_dim, embed_dim))
        bound = 1 / math.sqrt(fan_in)
        self.out_proj = Dense(embed_dim, embed_dim, has_bias=has_bias, weight_init=HeUniform(math.sqrt(5)),
                              bias_init=Uniform(bound), dtype=dtype)

        if add_bias_kv:
            self.bias_k = Parameter(initializer(XavierNormal(), (1, 1, embed_dim), dtype), 'bias_k')
            self.bias_v = Parameter(initializer(XavierNormal(), (1, 1, embed_dim), dtype), 'bias_v')
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn
        # Identity flags; refreshed on every call by __call__ and read by construct.
        self.k_is_v = False
        self.q_is_k = False
        self.dtype = dtype

    def __call__(self, *args, **kwargs):
        """
        Record whether `key` is `value` and `query` is `key` (object identity), then delegate
        to the parent ``__call__``.

        The identity checks are done here because `construct` cannot use ``is`` in graph mode.
        Each of `query`, `key` and `value` may be passed positionally or by keyword; a keyword
        takes precedence over the positional slot. An input that was not supplied at all leaves
        the corresponding flag ``False`` and is reported by the parent ``__call__``.
        """
        missing = object()

        def _lookup(name, position):
            # Resolve lazily: indexing `args` up front would raise IndexError
            # whenever the input is supplied by keyword only.
            if name in kwargs:
                return kwargs[name]
            if position < len(args):
                return args[position]
            return missing

        query = _lookup('query', 0)
        key = _lookup('key', 1)
        value = _lookup('value', 2)
        self.k_is_v = key is not missing and key is value
        self.q_is_k = query is not missing and query is key
        return super().__call__(*args, **kwargs)

    def construct(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                  need_weights: bool = True, attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True):
        """
        Run multi-head attention through `multi_head_attention_forward`.

        Returns ``(attn_output, attn_output_weights)`` when `need_weights` is true, otherwise the
        one-element tuple ``(attn_output,)``.

        Raises:
            ValueError: If `key_padding_mask` is given and is neither bool nor a floating type.
        """
        is_batched = query.ndim == 3
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != mindspore.bool_ and not ops.is_floating_point(key_padding_mask):
                raise ValueError(
                    "only bool and floating types of key_padding_mask are supported")

        if self.batch_first and is_batched:
            # k_is_v and q_is_k preprocess in __call__ since Graph mode do not support `is`
            # Swap axes 0 and 1 once per distinct tensor so shared inputs stay shared.
            if self.k_is_v:
                if self.q_is_k:
                    query = key = value = query.swapaxes(1, 0)
                else:
                    query, key = [x.swapaxes(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.swapaxes(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            # Separate q/k/v projection weights (in_proj_weight is None in this configuration).
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights,
                k_is_v=self.k_is_v, q_is_k=self.q_is_k, dtype=self.dtype)
        else:
            # Packed in_proj_weight covers all three projections.
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask, average_attn_weights=average_attn_weights,
                k_is_v=self.k_is_v, q_is_k=self.q_is_k, dtype=self.dtype)

        if self.batch_first and is_batched:
            # Undo the axis swap applied to the inputs above.
            attn_output = attn_output.swapaxes(1, 0)
        if need_weights:
            return attn_output, attn_output_weights
        return (attn_output,)
242
+
243
+
244
+ class TransformerEncoderLayer(Cell):
245
+ r"""
246
+ Transformer Encoder Layer. This is an implementation of the single layer of the transformer
247
+ encoder layer, including multihead attention and feedforward layer.
248
+
249
+ Args:
250
+ d_model (int): The number of features in the input tensor.
251
+ nhead (int): The number of heads in the MultiheadAttention modules.
252
+ dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
253
+ dropout (float): The dropout value. Default: ``0.1``.
254
+ activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
255
+ can be a string (``"relu"`` or ``"gelu"``), Cell instance (:class:`mindspore.nn.ReLU` or
256
+ :class:`mindspore.nn.GELU` ) or a callable ( :func:`mindspore.ops.relu` or
257
+ :func:`mindspore.ops.gelu` ). Default: ``"relu"``.
258
+ layer_norm_eps (float): The epsilon value in LayerNorm modules. Default: ``1e-5``.
259
+ batch_first (bool): If `batch_first=True` , then the shape of input and output tensors is
260
+ :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)` .
261
+ Default: ``False``.
262
+ norm_first (bool): If `norm_first = True`, layer norm is located prior to attention and feedforward
263
+ operations; if `norm_first = False`, layer norm is located after the attention and feedforward
264
+ operations. Default: ``False``.
265
+ dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .
266
+
267
+ Inputs:
268
+ - **src** (Tensor) - the sequence to the encoder layer. For unbatched input, the shape is
269
+ :math:`(S, E)` ; otherwise if `batch_first=False` , the shape is :math:`(S, N, E)` and if
270
+ `batch_first=True` , the shape is :math:`(N, S, E)`, where :math:`(S)` is the source sequence
271
+ length, :math:`(N)` is the batch number and :math:`(E)` is the feature number.
272
+ Supported types: float16, float32, float64.
273
+ - **src_mask** (Tensor, optional) - the mask for the src sequence. The shape is :math:`(S, S)`
274
+ or :math:`(N*nhead, S, S)`. Supported types: float16, float32, float64, bool. Default: ``None``.
275
+ - **src_key_padding_mask** (Tensor, optional) - the mask for the src keys per batch. The shape is
276
+ :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
277
+ float64, bool. Default: ``None``.
278
+
279
+ Outputs:
280
+ Tensor. The shape and dtype of Tensor is the same with `src` .
281
+
282
+ Raises:
283
+ ValueError: If the init argument `activation` is not str, callable or Cell instance.
284
+ ValueError: If the init argument `activation` is not :class:`mindspore.nn.ReLU`,
285
+ :class:`mindspore.nn.GELU` instance, :func:`mindspore.ops.relu`,
286
+ :func:`mindspore.ops.gelu`, "relu" or "gelu" .
287
+
288
+ Supported Platforms:
289
+ ``Ascend`` ``GPU`` ``CPU``
290
+
291
+ Examples:
292
+ >>> import mindspore as ms
293
+ >>> import numpy as np
294
+ >>> encoder_layer = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8)
295
+ >>> src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
296
+ >>> out = encoder_layer(src)
297
+ >>> print(out.shape)
298
+ (10, 32, 512)
299
+ >>> # Alternatively, when batch_first=True:
300
+ >>> encoder_layer = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
301
+ >>> src = ms.Tensor(np.random.rand(32, 10, 512), ms.float32)
302
+ >>> out = encoder_layer(src)
303
+ >>> print(out.shape)
304
+ (32, 10, 512)
305
+ """
306
+
307
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
             activation: Union[str, Cell, callable] = 'relu', layer_norm_eps: float = 1e-5,
             batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32):
    """
    Build the self-attention, feedforward, layer-norm and dropout sub-layers.

    Args:
        d_model (int): Feature size of the input tensor.
        nhead (int): Number of attention heads.
        dim_feedforward (int): Hidden size of the feedforward block.
        dropout (float): Dropout probability shared by all dropout sub-layers.
        activation (Union[str, callable, Cell]): ``"relu"``/``"gelu"``, an ``nn.ReLU``/``nn.GELU``
            instance, or ``ops.relu``/``ops.gelu``.
        layer_norm_eps (float): Epsilon passed to both LayerNorm sub-layers.
        batch_first (bool): Forwarded to the attention module.
        norm_first (bool): Whether layer norm runs before (True) or after (False) each sub-block.
        dtype: Parameter data type forwarded to every sub-layer.

    Raises:
        ValueError: If `activation` is not a str, a callable or a Cell instance.
        ValueError: If `activation` is a Cell other than ``nn.ReLU``/``nn.GELU``, a callable other
            than ``ops.relu``/``ops.gelu``, or a string other than ``"relu"``/``"gelu"``.
    """
    super().__init__()
    self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, dtype=dtype)
    # feedforward layer; the bias bound of each Dense is 1 / sqrt(fan_in) of its weight shape
    fan_in, _ = _calculate_fan_in_and_fan_out((dim_feedforward, d_model))
    bound = 1 / math.sqrt(fan_in)
    self.dense1 = Dense(d_model, dim_feedforward, weight_init=HeUniform(math.sqrt(5)),
                        bias_init=Uniform(bound), dtype=dtype)
    self.dropout = Dropout(p=dropout)
    fan_in1, _ = _calculate_fan_in_and_fan_out((d_model, dim_feedforward))
    bound1 = 1 / math.sqrt(fan_in1)
    self.dense2 = Dense(dim_feedforward, d_model, weight_init=HeUniform(math.sqrt(5)),
                        bias_init=Uniform(bound1), dtype=dtype)

    self.norm_first = norm_first
    self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
    self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
    self.dropout1 = Dropout(p=dropout)
    self.dropout2 = Dropout(p=dropout)
    # Keep the activation exactly as the caller supplied it; TransformerEncoder reads this
    # attribute to rebuild equivalent layers.
    self.activation1 = activation

    # The branches are mutually exclusive on purpose: a Cell instance is itself callable, so
    # testing `callable(activation)` independently would reject valid nn.ReLU/nn.GELU instances.
    if isinstance(activation, str):
        # string inputs of activation; unknown names raise ValueError in the helper
        activation = _get_activation_fn(activation)
    elif isinstance(activation, Cell):
        if not isinstance(activation, (ReLU, GELU)):
            raise ValueError(f"The argument 'activation' must be nn.ReLU or nn.GELU instance,"
                             f" but get {activation}.")
    elif callable(activation):
        if activation is not ops.relu and activation is not ops.gelu:
            raise ValueError(f"The argument 'activation' must be ops.relu or ops.gelu instance,"
                             f" but get {activation}.")
    else:
        raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
                         f" but get {activation}.")
    self.activation = activation

    # Plain copies of the constructor arguments, kept so the layer can be re-created.
    self.d_model = d_model
    self.nhead = nhead
    self.dim_feedforward = dim_feedforward
    self.dropout_num = dropout
    self.layernorm_eps = layer_norm_eps
    self.batch_first = batch_first
    self.dtype = dtype
354
+
355
def construct(self, src: Tensor, src_mask: Optional[Tensor] = None,
              src_key_padding_mask: Optional[Tensor] = None):
    """
    Run self-attention followed by the feedforward block, each with a residual connection.

    Raises:
        AssertionError: If `src_key_padding_mask` is given and is neither bool nor floating point.
    """
    if src_key_padding_mask is not None:
        mask_is_bool = src_key_padding_mask.dtype == mindspore.bool_
        if not (mask_is_bool or ops.is_floating_point(src_key_padding_mask)):
            raise AssertionError(
                "only bool and floating types of key_padding_mask are supported")

    if self.norm_first:
        # Pre-norm: normalize the input of each sub-block, then add the residual.
        hidden = src + self._sa_block(self.norm1(src), src_mask, src_key_padding_mask)
        return hidden + self._ff_block(self.norm2(hidden))

    # Post-norm: add the residual first, then normalize the sum.
    hidden = self.norm1(src + self._sa_block(src, src_mask, src_key_padding_mask))
    return self.norm2(hidden + self._ff_block(hidden))
372
+
373
def _sa_block(self, x, attn_mask, key_padding_mask):
    """Apply self-attention (query = key = value = `x`) and then `dropout1`."""
    attn_result = self.self_attn(x, x, x,
                                 attn_mask=attn_mask,
                                 key_padding_mask=key_padding_mask,
                                 need_weights=False)
    # Only the first element of the attention result is used.
    return self.dropout1(attn_result[0])
379
+
380
def _ff_block(self, x):
    """Apply dense1 -> activation -> dropout -> dense2 -> dropout2 to `x`."""
    hidden = self.dense1(x)
    hidden = self.activation(hidden)
    hidden = self.dropout(hidden)
    projected = self.dense2(hidden)
    return self.dropout2(projected)
383
+
384
+
385
class TransformerDecoderLayer(Cell):
    r"""
    Transformer Decoder Layer. This is an implementation of the single layer of the transformer
    decoder layer, including self-attention, cross attention and feedforward layer.

    Args:
        d_model (int): The number of expected features in the input tensor.
        nhead (int): The number of heads in the MultiheadAttention modules.
        dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
        dropout (float): The dropout value. Default: ``0.1``.
        activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
            can be a string (``"relu"`` or ``"gelu"``), Cell instance (:class:`mindspore.nn.ReLU` or
            :class:`mindspore.nn.GELU` ) or a callable ( :func:`mindspore.ops.relu` or
            :func:`mindspore.ops.gelu` ). Default: ``"relu"``.
        layer_norm_eps (float): The epsilon value in LayerNorm modules. Default: ``1e-5``.
        batch_first (bool): If `batch_first=True` , then the shape of input and output tensors is
            :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)`.
            Default: ``False``.
        norm_first (bool): If `norm_first = True`, layer norm is located prior to attention and feedforward
            operations; if `norm_first = False`, layer norm is located after the attention and feedforward
            operations. Default: ``False``.
        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .

    Inputs:
        - **tgt** (Tensor) - The sequence to the decoder layer. For unbatched input, the shape is
          :math:`(T, E)` ; otherwise if `batch_first=False` , the shape is :math:`(T, N, E)` and if
          `batch_first=True` , the shape is :math:`(N, T, E)`, where :math:`(T)` is the target sequence
          length. Supported types: float16, float32, float64.
        - **memory** (Tensor) - The sequence from the last layer of the encoder. Supported types: float16,
          float32, float64.
        - **tgt_mask** (Tensor, optional) - The mask of the tgt sequence. The shape is :math:`(T, T)`
          or :math:`(N*nhead, T, T)`. Supported types: float16, float32, float64, bool. Default: ``None``.
        - **memory_mask** (Tensor, optional) - The mask of the memory sequence. The shape is
          :math:`(T, S)` . Supported types: float16, float32, float64, bool. Default: ``None``.
        - **tgt_key_padding_mask** (Tensor, optional) - The mask of the tgt keys per batch. The shape is
          :math:`(T)` for unbatched input, otherwise :math:`(N, T)` . Supported types: float16, float32,
          float64, bool. Default: ``None``.
        - **memory_key_padding_mask** (Tensor, optional) - The mask of the memory keys per batch. The shape
          is :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
          float64, bool. Default: ``None``.

    Outputs:
        Tensor. The shape and dtype of Tensor is the same as `tgt` .

    Raises:
        ValueError: If the init argument `activation` is not str, callable or Cell instance.
        ValueError: If the init argument `activation` is not :class:`mindspore.nn.ReLU`,
            :class:`mindspore.nn.GELU` instance, :func:`mindspore.ops.relu`,
            :func:`mindspore.ops.gelu` , "relu" or "gelu" .

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> decoder_layer = ms.nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
        >>> tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
        >>> out = decoder_layer(tgt, memory)
        >>> print(out.shape)
        (20, 32, 512)
        >>> # Alternatively, when `batch_first` is ``True``:
        >>> decoder_layer = ms.nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> memory = ms.Tensor(np.random.rand(32, 10, 512), ms.float32)
        >>> tgt = ms.Tensor(np.random.rand(32, 20, 512), ms.float32)
        >>> out = decoder_layer(tgt, memory)
        >>> print(out.shape)
        (32, 20, 512)
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Cell, callable] = 'relu', layer_norm_eps: float = 1e-5,
                 batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, dtype=dtype)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, dtype=dtype)
        # feedforward layer; the bias bound of each Dense is 1 / sqrt(fan_in) of its weight shape
        fan_in, _ = _calculate_fan_in_and_fan_out((dim_feedforward, d_model))
        bound = 1 / math.sqrt(fan_in)
        self.dense1 = Dense(d_model, dim_feedforward, weight_init=HeUniform(math.sqrt(5)),
                            bias_init=Uniform(bound), dtype=dtype)
        self.dropout = Dropout(p=dropout)
        fan_in1, _ = _calculate_fan_in_and_fan_out((d_model, dim_feedforward))
        bound1 = 1 / math.sqrt(fan_in1)
        self.dense2 = Dense(dim_feedforward, d_model, weight_init=HeUniform(math.sqrt(5)),
                            bias_init=Uniform(bound1), dtype=dtype)

        self.norm_first = norm_first
        self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
        self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
        self.norm3 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
        self.dropout1 = Dropout(p=dropout)
        self.dropout2 = Dropout(p=dropout)
        self.dropout3 = Dropout(p=dropout)
        # Keep the activation exactly as the caller supplied it; TransformerDecoder reads this
        # attribute to rebuild equivalent layers.
        self.activation1 = activation

        # The branches are mutually exclusive on purpose: a Cell instance is itself callable, so
        # testing `callable(activation)` independently would reject valid nn.ReLU/nn.GELU instances.
        if isinstance(activation, str):
            # string inputs of activation; unknown names raise ValueError in the helper
            activation = _get_activation_fn(activation)
        elif isinstance(activation, Cell):
            if not isinstance(activation, (ReLU, GELU)):
                raise ValueError(f"The argument 'activation' must be nn.ReLU or nn.GELU instance,"
                                 f" but get {activation}.")
        elif callable(activation):
            if activation is not ops.relu and activation is not ops.gelu:
                raise ValueError(f"The argument 'activation' must be ops.relu or ops.gelu instance,"
                                 f" but get {activation}.")
        else:
            raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
                             f" but get {activation}.")
        self.activation = activation

        # Plain copies of the constructor arguments, kept so the layer can be re-created.
        self.d_model = d_model
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.dropout_num = dropout
        self.layernorm_eps = layer_norm_eps
        self.batch_first = batch_first
        self.dtype = dtype

    def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                  memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                  memory_key_padding_mask: Optional[Tensor] = None):
        """Run self-attention, cross-attention over `memory` and the feedforward block, each with a residual."""
        x = tgt
        if self.norm_first:
            # Pre-norm: normalize the input of each sub-block, then add the residual.
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)
            x = x + self._ff_block(self.norm3(x))
        else:
            # Post-norm: add the residual first, then normalize the sum.
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
            x = self.norm3(x + self._ff_block(x))

        return x

    def _sa_block(self, x, attn_mask, key_padding_mask):
        """Apply self-attention (query = key = value = `x`) and then `dropout1`."""
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    def _mha_block(self, x, mem, attn_mask, key_padding_mask):
        """Apply cross-attention (query = `x`, key = value = `mem`) and then `dropout2`."""
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                need_weights=False)[0]
        return self.dropout2(x)

    def _ff_block(self, x):
        """Apply dense1 -> activation -> dropout -> dense2 -> dropout3 to `x`."""
        x = self.dense2(self.dropout(self.activation(self.dense1(x))))
        return self.dropout3(x)
539
+
540
+
541
class TransformerEncoder(Cell):
    r"""
    Transformer Encoder module made of a stack of :class:`mindspore.nn.TransformerEncoderLayer`,
    each including multihead attention and a feedforward layer. Users can build the
    BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer (Cell): An instance of the :class:`mindspore.nn.TransformerEncoderLayer` class.
        num_layers (int): The number of encoder-layers in the encoder.
        norm (Cell, optional): The layer normalization module. Default: ``None``.

    Inputs:
        - **src** (Tensor) - The sequence to the encoder. For unbatched input, the shape is
          :math:`(S, E)` ; otherwise if `batch_first=False` in :class:`mindspore.nn.TransformerEncoderLayer`,
          the shape is :math:`(S, N, E)` and if `batch_first=True` , the shape is :math:`(N, S, E)`,
          where :math:`(S)` is the source sequence length, :math:`(N)` is the batch number and :math:`(E)` is
          the feature number. Supported types: float16, float32, float64.
        - **src_mask** (Tensor, optional) - The mask of the src sequence. The shape is :math:`(S, S)`
          or :math:`(N*nhead, S, S)` , where `nhead` is the argument in :class:`mindspore.nn.TransformerEncoderLayer`.
          Supported types: float16, float32, float64, bool. Default: ``None``.
        - **src_key_padding_mask** (Tensor, optional) - The mask of the src keys per batch. The shape is
          :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
          float64, bool. Default: ``None``.

    Outputs:
        Tensor. The shape and dtype of Tensor is the same as `src` .

    Raises:
        AssertionError: If the input argument `src_key_padding_mask` is not bool or floating types.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> encoder_layer = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = ms.nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
        >>> out = transformer_encoder(src)
        >>> print(out.shape)
        (10, 32, 512)
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        # Re-create a layer from the hyper-parameters recorded on `encoder_layer` rather than
        # reusing the passed instance; every stacked layer is a deep copy of this template.
        template = TransformerEncoderLayer(
            d_model=encoder_layer.d_model,
            nhead=encoder_layer.nhead,
            dim_feedforward=encoder_layer.dim_feedforward,
            dropout=encoder_layer.dropout_num,
            activation=encoder_layer.activation1,
            layer_norm_eps=encoder_layer.layernorm_eps,
            batch_first=encoder_layer.batch_first,
            norm_first=encoder_layer.norm_first,
            dtype=encoder_layer.dtype,
        )
        stacked = []
        for _ in range(num_layers):
            stacked.append(copy.deepcopy(template))
        self.layers = CellList(stacked)
        self.num_layers = num_layers
        self.norm = norm

    def construct(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None):
        """Pass `src` through every stacked layer in order, then through `norm` when one was given."""
        if src_key_padding_mask is not None:
            mask_is_bool = src_key_padding_mask.dtype == mindspore.bool_
            if not (mask_is_bool or ops.is_floating_point(src_key_padding_mask)):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported")

        output = src
        for layer in self.layers:
            output = layer(output, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is None:
            return output
        return self.norm(output)
610
+
611
+
612
class TransformerDecoder(Cell):
    r"""
    Transformer Decoder module made of a stack of :class:`mindspore.nn.TransformerDecoderLayer`,
    each including multihead self attention, cross attention and a feedforward layer.

    Args:
        decoder_layer (Cell): An instance of the :class:`mindspore.nn.TransformerDecoderLayer` class.
        num_layers (int): The number of decoder-layers in the decoder.
        norm (Cell, optional): The layer normalization module. Default: ``None``.

    Inputs:
        - **tgt** (Tensor) - The sequence to the decoder. For unbatched input, the shape is
          :math:`(T, E)` ; otherwise if `batch_first=False` in :class:`mindspore.nn.TransformerDecoderLayer`,
          the shape is :math:`(T, N, E)` and if `batch_first=True` , the shape is :math:`(N, T, E)`,
          where :math:`(T)` is the target sequence length, :math:`(N)` is the number of batches,
          and :math:`(E)` is the number of features. Supported types: float16, float32, float64.
        - **memory** (Tensor) - The sequence from the last layer of the encoder. Supported types: float16,
          float32, float64.
        - **tgt_mask** (Tensor, optional) - The mask of the tgt sequence. The shape is :math:`(T, T)`
          or :math:`(N*nhead, T, T)` , where `nhead` is the argument in :class:`mindspore.nn.TransformerDecoderLayer`.
          Supported types: float16, float32, float64, bool. Default: ``None``.
        - **memory_mask** (Tensor, optional) - The mask of the memory sequence. The shape is
          :math:`(T, S)` . Supported types: float16, float32, float64, bool. Default: ``None``.
        - **tgt_key_padding_mask** (Tensor, optional) - The mask of the tgt keys per batch. The shape is
          :math:`(T)` for unbatched input, otherwise :math:`(N, T)` . Supported types: float16, float32,
          float64, bool. Default: ``None``.
        - **memory_key_padding_mask** (Tensor, optional) - The mask of the memory keys per batch. The shape
          is :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
          float64, bool. Default: ``None``.

    Outputs:
        Tensor. The shape and dtype of Tensor is the same as `tgt` .

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> decoder_layer = ms.nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = ms.nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
        >>> tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
        >>> out = transformer_decoder(tgt, memory)
        >>> print(out.shape)
        (20, 32, 512)
    """

    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__()
        # Re-create a layer from the hyper-parameters recorded on `decoder_layer` rather than
        # reusing the passed instance; every stacked layer is a deep copy of this template.
        template = TransformerDecoderLayer(
            d_model=decoder_layer.d_model,
            nhead=decoder_layer.nhead,
            dim_feedforward=decoder_layer.dim_feedforward,
            dropout=decoder_layer.dropout_num,
            activation=decoder_layer.activation1,
            layer_norm_eps=decoder_layer.layernorm_eps,
            batch_first=decoder_layer.batch_first,
            norm_first=decoder_layer.norm_first,
            dtype=decoder_layer.dtype,
        )
        stacked = []
        for _ in range(num_layers):
            stacked.append(copy.deepcopy(template))
        self.layers = CellList(stacked)
        self.num_layers = num_layers
        self.norm = norm

    def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                  memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                  memory_key_padding_mask: Optional[Tensor] = None):
        """Pass `tgt` through every stacked layer in order, then through `norm` when one was given."""
        output = tgt
        for layer in self.layers:
            # Every layer receives the same `memory` and the same masks.
            output = layer(output, memory,
                           tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is None:
            return output
        return self.norm(output)
684
+
685
+
686
class Transformer(Cell):
    r"""
    Transformer module including encoder and decoder. With the default `norm_first=False` the module
    applies the residual addition before the layer normalization. The default hidden activation is `relu`.
    The details can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.

    Args:
        d_model (int): The number of expected features in the inputs tensor for Encoder and Decoder. Default: ``512``.
        nhead (int): The number of heads in the MultiheadAttention modules. Default: ``8``.
        num_encoder_layers (int): The number of encoder-layers in the encoder. Default: ``6``.
        num_decoder_layers (int): The number of decoder-layers in the decoder. Default: ``6``.
        dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
        dropout (float): The dropout value. Default: ``0.1``.
        activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
            can be a string (``"relu"`` or ``"gelu"``), Cell instance (:class:`mindspore.nn.ReLU` or
            :class:`mindspore.nn.GELU` ) or a callable ( :func:`mindspore.ops.relu` or
            :func:`mindspore.ops.gelu` ). Default: ``"relu"``.
        custom_encoder (Cell): Custom encoder. Default: ``None``.
        custom_decoder (Cell): Custom decoder. Default: ``None``.
        layer_norm_eps (float): The epsilon value in layer normalization module. Default: ``1e-5``.
        batch_first (bool): If `batch_first=True`, then the shape of input and output tensors is
            :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)` .
            Default: ``False``.
        norm_first (bool): If `norm_first = True`, layer norm is located prior to attention and feedforward
            operations; if `norm_first = False`, layer norm is located after the attention and feedforward
            operations. Default: ``False``.
        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .

    Inputs:
        - **src** (Tensor) - The source sequence to the encoder. For unbatched input, the shape is
          :math:`(S, E)` ; otherwise if `batch_first=False` , the shape is :math:`(S, N, E)` and if
          `batch_first=True` , the shape is :math:`(N, S, E)`, where :math:`(S)` is the source sequence
          length, :math:`(N)` is the batch number and :math:`(E)` is the feature number. Supported
          types: float16, float32, float64.
        - **tgt** (Tensor) - The target sequence to the decoder. For unbatched input, the shape is
          :math:`(T, E)` ; otherwise if `batch_first=False` , the shape is :math:`(T, N, E)` and if
          `batch_first=True` , the shape is :math:`(N, T, E)`, where :math:`(T)` is the target sequence
          length. Supported types: float16, float32, float64.
        - **src_mask** (Tensor, optional) - The mask of the src sequence. The shape is :math:`(S, S)`
          or :math:`(N*nhead, S, S)`. Supported types: float16, float32, float64, bool. Default: ``None``.
        - **tgt_mask** (Tensor, optional) - The mask of the tgt sequence. The shape is :math:`(T, T)`
          or :math:`(N*nhead, T, T)`. Supported types: float16, float32, float64, bool. Default: ``None``.
        - **memory_mask** (Tensor, optional) - The additive mask of the encoder output. The shape is
          :math:`(T, S)` . Supported types: float16, float32, float64, bool. Default: ``None``.
        - **src_key_padding_mask** (Tensor, optional) - The mask of src keys per batch. The shape is
          :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
          float64, bool. Default: ``None``.
        - **tgt_key_padding_mask** (Tensor, optional) - The mask of tgt keys per batch. The shape is
          :math:`(T)` for unbatched input, otherwise :math:`(N, T)` . Supported types: float16, float32,
          float64, bool. Default: ``None``.
        - **memory_key_padding_mask** (Tensor, optional) - The mask of memory keys per batch. The shape
          is :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16,
          float32, float64, bool. Default: ``None``.

    Outputs:
        Tensor. The shape is :math:`(T, E)` for unbatched input, otherwise if `batch_first=False` , the shape is
        :math:`(T, N, E)` and if `batch_first=True` , the shape is :math:`(N, T, E)`.

    Raises:
        ValueError: If the batch sizes of the inputs `src` and `tgt` are not equal.
        ValueError: If the number of features of the inputs `src` and `tgt` is not equal to that of `d_model`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> transformer_model = ms.nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
        >>> tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
        >>> out = transformer_model(src, tgt)
        >>> print(out.shape)
        (20, 32, 512)
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Cell, callable] = 'relu', custom_encoder: Optional[Cell] = None,
                 custom_decoder: Optional[Cell] = None, layer_norm_eps: float = 1e-5,
                 batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32):
        super().__init__()

        # A user-supplied encoder/decoder replaces the default stack entirely.
        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first, dtype=dtype)
            encoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first, dtype=dtype)
            decoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        # Re-initialize every parameter with more than one dimension using 'xavier_uniform';
        # one-dimensional parameters keep the initialization chosen by their own layer.
        for _, p in self.parameters_and_names():
            if p.ndim > 1:
                p.set_data(initializer('xavier_uniform', p.shape, p.dtype))

        self.d_model = d_model
        self.nhead = nhead

        self.batch_first = batch_first

    def construct(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                  memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                  tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None):
        """
        Encode `src`, then decode `tgt` against the encoder output.

        Raises:
            ValueError: If `src` is batched and the batch sizes of `src` and `tgt` differ.
            ValueError: If the last dimension of `src` or `tgt` is not `d_model`.
        """
        is_batched = src.ndim == 3
        # The batch axis is 0 when batch_first is set, otherwise 1. The target batch size must be
        # read from `tgt`; reading it from `src` would make the mismatch check a no-op.
        if self.batch_first:
            src_batch_size = src.shape[0]
            tgt_batch_size = tgt.shape[0]
        else:
            src_batch_size = src.shape[1]
            tgt_batch_size = tgt.shape[1]
        # For unbatched input these indices do not refer to a batch axis, hence the guard.
        if src_batch_size != tgt_batch_size and is_batched:
            raise ValueError("The number of batch size for 'src' and 'tgt' must be equal.")

        if src.shape[-1] != self.d_model or tgt.shape[-1] != self.d_model:
            raise ValueError("The number of features for 'src' and 'tgt' must be equal to `d_model`.")

        memory = self.encoder(src, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output
815
+
816
+
817
def _get_activation_fn(activation: str):
    """
    Map an activation name to the matching functional op.

    Returns ``ops.relu`` for ``"relu"`` and ``ops.gelu`` for ``"gelu"``.

    Raises:
        ValueError: If `activation` is any other value.
    """
    supported = (("relu", ops.relu), ("gelu", ops.gelu))
    for name, fn in supported:
        if activation == name:
            return fn

    raise ValueError(f"The activation must be relu/gelu, but get {activation}")