mindspore-2.0.0rc1-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.0-cp38-cp38-manylinux1_x86_64.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (884)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  201. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  202. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  203. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  204. mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
  205. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  206. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  207. mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
  208. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  209. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  210. mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
  211. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  212. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  213. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  214. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  215. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  216. mindspore/log.py +9 -6
  217. mindspore/mindrecord/filereader.py +33 -4
  218. mindspore/mindrecord/filewriter.py +70 -35
  219. mindspore/mindrecord/mindpage.py +40 -34
  220. mindspore/mindrecord/shardreader.py +1 -1
  221. mindspore/mindrecord/shardsegment.py +1 -1
  222. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  223. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  224. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  225. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  226. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  227. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  228. mindspore/nn/cell.py +463 -169
  229. mindspore/nn/dynamic_lr.py +47 -43
  230. mindspore/nn/layer/activation.py +225 -82
  231. mindspore/nn/layer/basic.py +121 -79
  232. mindspore/nn/layer/channel_shuffle.py +21 -21
  233. mindspore/nn/layer/combined.py +33 -26
  234. mindspore/nn/layer/container.py +277 -22
  235. mindspore/nn/layer/conv.py +441 -304
  236. mindspore/nn/layer/dense.py +19 -13
  237. mindspore/nn/layer/embedding.py +62 -49
  238. mindspore/nn/layer/flash_attention.py +264 -0
  239. mindspore/nn/layer/image.py +50 -39
  240. mindspore/nn/layer/math.py +62 -51
  241. mindspore/nn/layer/normalization.py +219 -167
  242. mindspore/nn/layer/padding.py +58 -70
  243. mindspore/nn/layer/pooling.py +334 -287
  244. mindspore/nn/layer/rnn_cells.py +53 -38
  245. mindspore/nn/layer/rnns.py +59 -56
  246. mindspore/nn/layer/thor_layer.py +52 -44
  247. mindspore/nn/layer/timedistributed.py +6 -4
  248. mindspore/nn/layer/transformer.py +284 -164
  249. mindspore/nn/learning_rate_schedule.py +34 -25
  250. mindspore/nn/loss/__init__.py +3 -2
  251. mindspore/nn/loss/loss.py +554 -311
  252. mindspore/nn/optim/ada_grad.py +12 -9
  253. mindspore/nn/optim/adadelta.py +14 -11
  254. mindspore/nn/optim/adafactor.py +19 -16
  255. mindspore/nn/optim/adam.py +62 -47
  256. mindspore/nn/optim/adamax.py +13 -10
  257. mindspore/nn/optim/adasum.py +12 -8
  258. mindspore/nn/optim/asgd.py +10 -9
  259. mindspore/nn/optim/ftrl.py +20 -17
  260. mindspore/nn/optim/lamb.py +16 -12
  261. mindspore/nn/optim/lars.py +8 -6
  262. mindspore/nn/optim/lazyadam.py +25 -20
  263. mindspore/nn/optim/momentum.py +10 -7
  264. mindspore/nn/optim/optimizer.py +61 -9
  265. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  266. mindspore/nn/optim/rmsprop.py +17 -13
  267. mindspore/nn/optim/rprop.py +30 -17
  268. mindspore/nn/optim/sgd.py +40 -23
  269. mindspore/nn/optim/thor.py +24 -26
  270. mindspore/nn/probability/bijector/bijector.py +11 -11
  271. mindspore/nn/probability/bijector/exp.py +1 -1
  272. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  273. mindspore/nn/probability/bijector/invert.py +1 -1
  274. mindspore/nn/probability/bijector/power_transform.py +29 -29
  275. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  276. mindspore/nn/probability/bijector/softplus.py +5 -5
  277. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  278. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  279. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  280. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  281. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  282. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  283. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  284. mindspore/nn/probability/distribution/beta.py +8 -8
  285. mindspore/nn/probability/distribution/categorical.py +23 -15
  286. mindspore/nn/probability/distribution/cauchy.py +5 -6
  287. mindspore/nn/probability/distribution/distribution.py +3 -3
  288. mindspore/nn/probability/distribution/exponential.py +4 -4
  289. mindspore/nn/probability/distribution/gamma.py +10 -10
  290. mindspore/nn/probability/distribution/geometric.py +8 -8
  291. mindspore/nn/probability/distribution/gumbel.py +8 -9
  292. mindspore/nn/probability/distribution/half_normal.py +5 -5
  293. mindspore/nn/probability/distribution/laplace.py +5 -5
  294. mindspore/nn/probability/distribution/log_normal.py +12 -11
  295. mindspore/nn/probability/distribution/logistic.py +8 -8
  296. mindspore/nn/probability/distribution/normal.py +6 -5
  297. mindspore/nn/probability/distribution/poisson.py +10 -11
  298. mindspore/nn/probability/distribution/student_t.py +8 -9
  299. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  300. mindspore/nn/probability/distribution/uniform.py +11 -11
  301. mindspore/nn/reinforcement/tensor_array.py +2 -2
  302. mindspore/nn/sparse/sparse.py +9 -9
  303. mindspore/nn/wrap/cell_wrapper.py +188 -63
  304. mindspore/nn/wrap/grad_reducer.py +21 -12
  305. mindspore/nn/wrap/loss_scale.py +136 -49
  306. mindspore/numpy/__init__.py +4 -4
  307. mindspore/numpy/array_creations.py +55 -56
  308. mindspore/numpy/array_ops.py +134 -35
  309. mindspore/numpy/logic_ops.py +66 -20
  310. mindspore/numpy/math_ops.py +142 -139
  311. mindspore/numpy/utils_const.py +2 -2
  312. mindspore/offline_debug/convert_async.py +2 -2
  313. mindspore/ops/_grad_experimental/__init__.py +7 -5
  314. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  315. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  316. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  317. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  318. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  319. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  320. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  321. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  322. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  323. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  324. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  325. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  326. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  327. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  328. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  329. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  330. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  331. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  332. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  333. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  334. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  335. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  336. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  337. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  338. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  339. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  340. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  341. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  342. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  343. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  344. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  345. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  346. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  347. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  348. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  349. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  350. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  351. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  352. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  353. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  354. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  355. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  356. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  357. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  358. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  359. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  360. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  361. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  362. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  363. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  364. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  365. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  366. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  367. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  368. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  369. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  370. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  371. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  372. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  373. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  374. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  375. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  376. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  377. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  378. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  379. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  380. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  381. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  382. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  383. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  384. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  385. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  386. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  387. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  388. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  389. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  390. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  391. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  392. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  393. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  394. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  395. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  396. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  397. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  398. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  399. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  400. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  401. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  402. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  403. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  404. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  405. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  406. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  407. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  408. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  409. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  410. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  411. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  412. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  413. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  414. mindspore/ops/_primitive_cache.py +1 -1
  415. mindspore/ops/_tracefunc.py +241 -0
  416. mindspore/ops/_utils/utils.py +10 -2
  417. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  418. mindspore/ops/_vmap/vmap_base.py +5 -4
  419. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  420. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  421. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  422. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  423. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  424. mindspore/ops/arg_dtype_cast.py +54 -0
  425. mindspore/ops/composite/__init__.py +7 -5
  426. mindspore/ops/composite/base.py +78 -34
  427. mindspore/ops/composite/math_ops.py +5 -695
  428. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  429. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  430. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  431. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  432. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  433. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  434. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  436. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  437. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  438. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  440. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  441. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  442. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  443. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  444. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  445. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  446. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  447. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  448. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  449. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  450. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  451. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  452. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  453. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  454. mindspore/ops/deprecated.py +304 -0
  455. mindspore/ops/function/__init__.py +41 -4
  456. mindspore/ops/function/array_func.py +1108 -467
  457. mindspore/ops/function/clip_func.py +94 -27
  458. mindspore/ops/function/debug_func.py +3 -1
  459. mindspore/ops/function/grad/grad_func.py +82 -73
  460. mindspore/ops/function/image_func.py +28 -12
  461. mindspore/ops/function/linalg_func.py +135 -39
  462. mindspore/ops/function/math_func.py +3779 -894
  463. mindspore/ops/function/nn_func.py +1584 -657
  464. mindspore/ops/function/parameter_func.py +13 -3
  465. mindspore/ops/function/random_func.py +247 -153
  466. mindspore/ops/function/sparse_func.py +14 -11
  467. mindspore/ops/function/sparse_unary_func.py +173 -47
  468. mindspore/ops/function/spectral_func.py +8 -4
  469. mindspore/ops/function/vmap_func.py +8 -7
  470. mindspore/ops/functional.py +47 -16
  471. mindspore/ops/op_info_register.py +346 -86
  472. mindspore/ops/operations/__init__.py +38 -22
  473. mindspore/ops/operations/_grad_ops.py +145 -149
  474. mindspore/ops/operations/_inner_ops.py +298 -56
  475. mindspore/ops/operations/_ms_kernel.py +3 -3
  476. mindspore/ops/operations/_quant_ops.py +24 -28
  477. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  478. mindspore/ops/operations/_scalar_ops.py +115 -0
  479. mindspore/ops/operations/_sequence_ops.py +148 -10
  480. mindspore/ops/operations/_tensor_array.py +1 -1
  481. mindspore/ops/operations/_thor_ops.py +2 -2
  482. mindspore/ops/operations/array_ops.py +1239 -561
  483. mindspore/ops/operations/comm_ops.py +166 -90
  484. mindspore/ops/operations/control_ops.py +3 -3
  485. mindspore/ops/operations/custom_ops.py +124 -102
  486. mindspore/ops/operations/debug_ops.py +24 -11
  487. mindspore/ops/operations/image_ops.py +86 -71
  488. mindspore/ops/operations/inner_ops.py +18 -13
  489. mindspore/ops/operations/linalg_ops.py +30 -11
  490. mindspore/ops/operations/math_ops.py +1730 -435
  491. mindspore/ops/operations/nn_ops.py +1953 -943
  492. mindspore/ops/operations/other_ops.py +65 -43
  493. mindspore/ops/operations/random_ops.py +258 -98
  494. mindspore/ops/operations/rl_ops.py +4 -36
  495. mindspore/ops/operations/sparse_ops.py +38 -33
  496. mindspore/ops/operations/spectral_ops.py +8 -4
  497. mindspore/ops/primitive.py +66 -44
  498. mindspore/ops/signature.py +5 -5
  499. mindspore/parallel/_auto_parallel_context.py +80 -19
  500. mindspore/parallel/_cost_model_context.py +42 -0
  501. mindspore/parallel/_offload_context.py +162 -72
  502. mindspore/parallel/_parallel_serialization.py +2 -2
  503. mindspore/parallel/_ps_context.py +16 -4
  504. mindspore/parallel/_recovery_context.py +2 -1
  505. mindspore/parallel/_tensor.py +15 -13
  506. mindspore/parallel/_transformer/layers.py +8 -6
  507. mindspore/parallel/_transformer/loss.py +1 -0
  508. mindspore/parallel/_transformer/moe.py +7 -7
  509. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  510. mindspore/parallel/_transformer/transformer.py +34 -14
  511. mindspore/parallel/_utils.py +36 -14
  512. mindspore/parallel/algo_parameter_config.py +114 -20
  513. mindspore/parallel/checkpoint_transform.py +16 -18
  514. mindspore/parallel/shard.py +16 -13
  515. mindspore/profiler/__init__.py +1 -1
  516. mindspore/profiler/common/struct_type.py +3 -3
  517. mindspore/profiler/common/util.py +3 -2
  518. mindspore/profiler/envprofiling.py +11 -4
  519. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  520. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  521. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  522. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  523. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  524. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  525. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  526. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  527. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  528. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  529. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  530. mindspore/profiler/parser/flops_parser.py +15 -11
  531. mindspore/profiler/parser/framework_parser.py +92 -73
  532. mindspore/profiler/parser/hccl_parser.py +16 -12
  533. mindspore/profiler/parser/integrator.py +22 -11
  534. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  535. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  536. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  537. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  538. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  539. mindspore/profiler/parser/optime_parser.py +1 -1
  540. mindspore/profiler/parser/profiler_info.py +4 -5
  541. mindspore/profiler/parser/step_trace_parser.py +11 -14
  542. mindspore/profiler/profiling.py +678 -377
  543. mindspore/rewrite/api/node.py +211 -54
  544. mindspore/rewrite/api/node_type.py +5 -0
  545. mindspore/rewrite/api/pattern_engine.py +22 -23
  546. mindspore/rewrite/api/scoped_value.py +20 -17
  547. mindspore/rewrite/api/symbol_tree.py +252 -106
  548. mindspore/rewrite/api/tree_node_helper.py +3 -0
  549. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  550. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  551. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  552. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  553. mindspore/rewrite/common/rewrite_elog.py +5 -1
  554. mindspore/rewrite/namer.py +51 -51
  555. mindspore/rewrite/namespace.py +14 -5
  556. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  557. mindspore/rewrite/node/call_function.py +79 -0
  558. mindspore/rewrite/node/cell_container.py +135 -0
  559. mindspore/rewrite/node/control_flow.py +88 -0
  560. mindspore/rewrite/{node.py → node/node.py} +313 -247
  561. mindspore/rewrite/node/node_manager.py +254 -0
  562. mindspore/rewrite/node/node_topological_manager.py +243 -0
  563. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  564. mindspore/rewrite/parsers/assign_parser.py +225 -239
  565. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  566. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  567. mindspore/rewrite/parsers/constant_parser.py +9 -6
  568. mindspore/rewrite/parsers/container_parser.py +9 -7
  569. mindspore/rewrite/parsers/for_parser.py +36 -15
  570. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  571. mindspore/rewrite/parsers/if_parser.py +28 -24
  572. mindspore/rewrite/parsers/module_parser.py +202 -25
  573. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  574. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  575. mindspore/rewrite/parsers/return_parser.py +6 -6
  576. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  577. mindspore/rewrite/sparsify/sparsify.py +4 -1
  578. mindspore/rewrite/sparsify/utils.py +11 -5
  579. mindspore/rewrite/symbol_tree.py +577 -732
  580. mindspore/rewrite/symbol_tree_builder.py +9 -175
  581. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  582. mindspore/run_check/_check_version.py +46 -39
  583. mindspore/run_check/run_check.py +3 -2
  584. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  585. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  586. mindspore/scipy/__init__.py +1 -1
  587. mindspore/scipy/linalg.py +67 -61
  588. mindspore/scipy/ops.py +5 -41
  589. mindspore/scipy/ops_grad.py +3 -2
  590. mindspore/scipy/ops_wrapper.py +5 -5
  591. mindspore/scipy/optimize/line_search.py +8 -8
  592. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  593. mindspore/scipy/optimize/minimize.py +16 -12
  594. mindspore/scipy/utils.py +1 -52
  595. mindspore/scipy/utils_const.py +4 -4
  596. mindspore/train/__init__.py +4 -4
  597. mindspore/train/_utils.py +13 -5
  598. mindspore/train/amp.py +410 -148
  599. mindspore/train/anf_ir_pb2.py +16 -4
  600. mindspore/train/callback/_backup_and_restore.py +8 -11
  601. mindspore/train/callback/_callback.py +80 -3
  602. mindspore/train/callback/_checkpoint.py +82 -51
  603. mindspore/train/callback/_early_stop.py +12 -15
  604. mindspore/train/callback/_history.py +1 -1
  605. mindspore/train/callback/_lambda_callback.py +13 -13
  606. mindspore/train/callback/_landscape.py +21 -17
  607. mindspore/train/callback/_loss_monitor.py +9 -10
  608. mindspore/train/callback/_on_request_exit.py +16 -33
  609. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  610. mindspore/train/callback/_summary_collector.py +44 -30
  611. mindspore/train/callback/_time_monitor.py +62 -12
  612. mindspore/train/data_sink.py +10 -16
  613. mindspore/train/dataset_helper.py +154 -86
  614. mindspore/train/loss_scale_manager.py +14 -9
  615. mindspore/train/metrics/__init__.py +10 -2
  616. mindspore/train/metrics/accuracy.py +1 -1
  617. mindspore/train/metrics/auc.py +1 -1
  618. mindspore/train/metrics/bleu_score.py +2 -2
  619. mindspore/train/metrics/confusion_matrix.py +14 -14
  620. mindspore/train/metrics/cosine_similarity.py +3 -3
  621. mindspore/train/metrics/dice.py +1 -1
  622. mindspore/train/metrics/fbeta.py +1 -1
  623. mindspore/train/metrics/hausdorff_distance.py +8 -6
  624. mindspore/train/metrics/mean_surface_distance.py +5 -4
  625. mindspore/train/metrics/metric.py +49 -17
  626. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  627. mindspore/train/metrics/perplexity.py +1 -1
  628. mindspore/train/metrics/precision.py +2 -2
  629. mindspore/train/metrics/recall.py +2 -3
  630. mindspore/train/metrics/roc.py +7 -7
  631. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  632. mindspore/train/metrics/topk.py +7 -4
  633. mindspore/train/mind_ir_pb2.py +193 -48
  634. mindspore/train/model.py +377 -133
  635. mindspore/train/serialization.py +697 -245
  636. mindspore/train/summary/_summary_adapter.py +5 -2
  637. mindspore/train/summary/_writer_pool.py +4 -3
  638. mindspore/train/summary/summary_record.py +25 -23
  639. mindspore/train/train_thor/convert_utils.py +39 -23
  640. mindspore/train/train_thor/dataset_helper.py +4 -3
  641. mindspore/train/train_thor/model_thor.py +8 -8
  642. mindspore/version.py +1 -1
  643. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  644. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
  645. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  646. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  647. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  648. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  649. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  650. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  651. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  652. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  653. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  654. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  655. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  656. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  657. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  658. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  659. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  660. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  661. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  662. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  663. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  664. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  665. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  666. mindspore/_extends/graph_kernel/expander.py +0 -80
  667. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  668. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  669. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  670. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  671. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  672. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  673. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  674. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  675. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  676. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  677. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  678. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  679. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  680. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  681. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  682. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  683. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  684. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  685. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  686. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  687. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  688. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  689. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  690. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  691. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  692. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  693. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  694. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  695. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  696. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  697. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  698. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  699. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  700. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  701. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  702. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  703. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  704. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  705. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  706. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  707. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  708. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  709. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  710. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  711. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  712. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  713. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  714. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  715. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  716. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  717. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  718. mindspore/dataset/engine/graphdata.py +0 -1586
  719. mindspore/include/api/net.h +0 -142
  720. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  721. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  722. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  723. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  724. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  725. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  726. mindspore/ops/_grad/grad_other_ops.py +0 -89
  727. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  728. mindspore/ops/_grad/grad_sparse.py +0 -323
  729. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  730. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  731. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  732. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  733. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  734. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  735. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  736. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  737. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  738. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  739. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  740. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  741. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  742. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  743. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  744. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  745. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  746. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  747. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  748. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  749. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  751. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  752. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  753. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  754. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  755. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  756. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  758. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  759. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  760. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  761. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  762. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  763. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  764. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  765. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  766. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  767. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  768. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  769. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  770. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  771. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  772. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  774. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  775. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  776. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  777. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  778. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  779. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  780. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  784. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  785. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  786. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  787. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  789. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  790. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  791. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  792. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  793. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  794. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  796. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  797. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  798. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  799. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  800. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  801. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  802. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  803. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  804. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  805. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  806. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  807. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  808. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  809. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  810. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  811. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  812. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  814. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  815. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  816. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  817. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  818. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  819. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  820. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  821. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  822. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  823. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  824. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  825. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  826. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  827. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  828. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  829. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  830. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  831. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  832. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  833. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  834. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  835. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  836. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  837. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  839. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  840. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  841. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  842. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  843. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  844. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  845. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  846. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  847. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  848. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  849. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  850. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  851. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  852. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  853. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  854. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  855. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  856. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  857. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  858. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  859. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  860. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  861. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  862. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  863. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  864. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  865. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  866. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  867. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  868. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  869. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  870. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  871. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  872. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  873. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  874. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  875. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  876. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  877. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  878. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  879. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  880. mindspore/rewrite/node_visitor.py +0 -44
  881. mindspore/rewrite/topological_manager.py +0 -203
  882. mindspore/scipy/sparse/linalg.py +0 -192
  883. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  884. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -16,11 +16,11 @@
16
16
  Transformer Cells module, include TransformerEncoderLayer, TransformerDecoderLayer,
17
17
  TransformerEncoder, TransformerDecoder, Transformer.
18
18
  """
19
- import copy
20
19
  import math
21
20
  from typing import Union, Optional
22
21
  import mindspore
23
22
  import mindspore.ops as ops
23
+ import mindspore.common.dtype as mstype
24
24
  from mindspore.common.tensor import Tensor
25
25
  from mindspore.common.parameter import Parameter
26
26
  from mindspore.common.initializer import initializer, XavierNormal, XavierUniform, \
@@ -36,24 +36,17 @@ __all__ = ['MultiheadAttention', 'TransformerEncoderLayer', 'TransformerDecoderL
36
36
  'TransformerEncoder', 'TransformerDecoder', 'Transformer']
37
37
 
38
38
 
39
- class _Linear(Dense):
40
- def __init__(self, in_channels, out_channels, has_bias=True):
41
- fan_in, _ = _calculate_fan_in_and_fan_out((out_channels, in_channels))
42
- bound = 1 / math.sqrt(fan_in)
43
- super().__init__(in_channels, out_channels, weight_init=HeUniform(math.sqrt(5)),
44
- bias_init=Uniform(bound), has_bias=has_bias, activation=None)
45
-
46
-
47
39
  class MultiheadAttention(Cell):
48
40
  r"""
49
41
  This is an implementation of multihead attention in the paper `Attention is all you need
50
- <https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector with source length, and the
51
- key and value vector with target length, the attention will be performed as the following
42
+ <https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector, the key vector and value vector,
43
+ the attention will be performed as the following:
52
44
 
53
45
  .. math::
54
- MultiHeadAttention(query, key, vector) = Concat(head_1, \dots, head_h)W^O
46
+ MultiHeadAttention(query, key, value) = Concat(head_1, \dots, head_h)W^O
55
47
 
56
- where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. The default is with a bias.
48
+ where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`, and :math:`W^O` , :math:`W_i^Q` , :math:`W_i^K` ,
49
+ :math:`W_i^V` are weight matrices. The default input / output projection layers is with a bias.
57
50
 
58
51
  if query, key and value tensor is same, then it will be self attention.
59
52
 
@@ -70,36 +63,37 @@ class MultiheadAttention(Cell):
  vdim (int): Total number of features for values. Default: ``None`` (`vdim=embed_dim`).
  batch_first (bool): If ``True``, then the input and output shape are :math:`(batch, seq, feature)` ,
  else :math:`(seq, batch, feature)` . Default: ``False``.
+ dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .

  Inputs:
  - **query** (Tensor): The query embeddings. If `query` is unbatched, the shape is :math:`(L, E_q)`,
  otherwise the shape is :math:`(L, N, E_q)` when `batch_first=False` or :math:`(N, L, E_q)` when
- `batch_first=True`, where :math:`L` is the target sequence length, :math:`N` is the batch size,
- and :math:`E_q` is the query embedding dimension `embed_dim`. Queries are compared against
- key-value pairs to produce the output. See "Attention Is All You Need" for more details.
+ `batch_first=True` , where :math:`L` is the target sequence length, :math:`N` is the batch size,
+ and :math:`E_q` is the query embedding dimension `embed_dim`. Supported types: float16, float32,
+ float64. Queries are compared against key-value pairs to produce the output.
  - **key** (Tensor): The key embeddings. If `key` is unbatched, the shape is :math:`(S, E_k)`, otherwise
  the shape is :math:`(S, N, E_k)` when `batch_first=False` or :math:`(N, S, E_k)` when
- `batch_first=True`, where :math:`S` is the source sequence length, :math:`N` is the batch size,
- and :math:`E_k` is the key embedding dimension `kdim`. See "Attention Is All You Need" for more details.
+ `batch_first=True` , where :math:`S` is the source sequence length, :math:`N` is the batch size,
+ and :math:`E_k` is the key embedding dimension `kdim`. Supported types: float16, float32, float64.
  - **value** (Tensor): The value embeddings. If `value` is unbatched, the shape is :math:`(S, E_v)`,
  otherwise the shape is :math:`(S, N, E_v)` when `batch_first=False` or :math:`(N, S, E_v)` when
- `batch_first=True`, where :math:`S` is the source sequence length, :math:`N` is the batch size,
- and :math:`E_v` is the value embedding dimension `vdim`. See "Attention Is All You Need" for more details.
+ `batch_first=True` , where :math:`S` is the source sequence length, :math:`N` is the batch size,
+ and :math:`E_v` is the value embedding dimension `vdim`. Supported types: float16, float32, float64.
  - **key_padding_mask** (Tensor, optional): If specified, a mask of shape :math:`(N, S)` indicating which
  elements within `key` to ignore for the purpose of attention (i.e. treat as "padding").
- For unbatched `query`, shape should be :math:`(S)`. Binary and byte masks are supported.
+ For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported.
  For a binary mask, a ``True`` value indicates that the corresponding `key` value will be ignored for
  the purpose of attention. For a float mask, it will be directly added to the corresponding `key` value.
+ Supported float types: float16, float32, float64. Default: ``None``.
  - **need_weights** (bool): Whether to return `attn_output_weights` in addition to `attn_outputs`.
  Default: ``True``.
  - **attn_mask** (Tensor, optional): If specified, a 2D or 3D mask preventing attention to certain positions.
- Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the
+ Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num_heads}, L, S)`, where :math:`N` is the
  batch size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length.
  A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry
- in the batch. Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates
- that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that
- the corresponding position is not allowed to attend. For a float mask, the mask values will be added to
- the attention weight.
+ in the batch. For a binary mask, a ``True`` value indicates that the corresponding position is not allowed
+ to attend. For a float mask, the mask values will be added to the attention weight.
+ Supported float types: float16, float32, float64. Default: ``None``.
  - **average_attn_weights** (bool): If true, indicates that the returned `attn_weights` should be averaged
  across heads. Otherwise, `attn_weights` are provided separately per head. Note that this flag only
  has an effect when `need_weights=True`. Default: ``True`` (i.e. average weights across heads)
@@ -109,33 +103,39 @@ class MultiheadAttention(Cell):

  - **attn_output** - Attention outputs. If input is unbatched, the output shape is :math:`(L, E)`, otherwise
  the output shape is :math:`(L, N, E)` when `batch_first=False` or :math:`(N, L, E)` when
- `batch_first=True`, where :math:`L` is the target sequence length, :math:`N` is the batch size,
+ `batch_first=True` , where :math:`L` is the target sequence length, :math:`N` is the batch size,
  and :math:`E` is the embedding dimension `embed_dim`.
  - **attn_output_weights** - Only returned when `need_weights=True`. If `average_attn_weights=True`,
  returns attention weights averaged across heads with shape :math:`(L, S)` when input is unbatched or
  :math:`(N, L, S)` when input is batched, where :math:`N` is the batch size, :math:`L` is
  the target sequence length, and :math:`S` is the source sequence length.
  If `average_attn_weights=False`, returns attention weights per
- head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or
- :math:`(N, \text{num\_heads}, L, S)` when input is batched.
+ head of shape :math:`(\text{num_heads}, L, S)` when input is unbatched or
+ :math:`(N, \text{num_heads}, L, S)` when input is batched.
+
+ Raises:
+ ValueError: If the init argument `embed_dim` is not divisible by `num_heads`.
+ TypeError: If the input argument `key_padding_mask` is not of bool or floating-point type.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
+ >>> import mindspore as ms
+ >>> import numpy as np
  >>> embed_dim, num_heads = 128, 8
  >>> seq_length, batch_size = 10, 8
- >>> query = Tensor(np.random.randn(seq_length, batch_size, embed_dim), mindspore.float32)
- >>> key = Tensor(np.random.randn(seq_length, batch_size, embed_dim), mindspore.float32)
- >>> value = Tensor(np.random.randn(seq_length, batch_size, embed_dim), mindspore.float32)
- >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
+ >>> query = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
+ >>> key = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
+ >>> value = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
+ >>> multihead_attn = ms.nn.MultiheadAttention(embed_dim, num_heads)
  >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
  >>> print(attn_output.shape)
  (10, 8, 128)
  """

- def __init__(self, embed_dim, num_heads, dropout=0., has_bias=True, add_bias_kv=False,
- add_zero_attn=False, kdim=None, vdim=None, batch_first=False):
+ def __init__(self, embed_dim, num_heads, dropout=0.0, has_bias=True, add_bias_kv=False,
+ add_zero_attn=False, kdim=None, vdim=None, batch_first=False, dtype=mstype.float32):
  super().__init__()
  self.embed_dim = embed_dim
  self.kdim = kdim if kdim is not None else embed_dim
@@ -149,32 +149,39 @@ class MultiheadAttention(Cell):
  if self.head_dim * num_heads != self.embed_dim:
  raise ValueError("The init argument 'embed_dim' must be divisible by 'num_heads'.")

+ if dtype is None:
+ dtype = mstype.float32
  if not self._qkv_same_embed_dim:
- self.q_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, embed_dim)), 'q_proj_weight')
- self.k_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.kdim)), 'k_proj_weight')
- self.v_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.vdim)), 'v_proj_weight')
+ self.q_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, embed_dim), dtype), 'q_proj_weight')
+ self.k_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.kdim), dtype), 'k_proj_weight')
+ self.v_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.vdim), dtype), 'v_proj_weight')
  self.in_proj_weight = None
  else:
- self.in_proj_weight = Parameter(initializer(XavierUniform(), (3 * embed_dim, embed_dim)), 'in_proj_weight')
+ self.in_proj_weight = Parameter(initializer(XavierUniform(), (3 * embed_dim, embed_dim), dtype),
+ 'in_proj_weight')
  self.q_proj_weight = None
  self.k_proj_weight = None
  self.v_proj_weight = None

  if has_bias:
- self.in_proj_bias = Parameter(initializer('zeros', (3 * embed_dim)), 'in_proj_bias')
+ self.in_proj_bias = Parameter(initializer('zeros', (3 * embed_dim), dtype), 'in_proj_bias')
  else:
  self.in_proj_bias = None
- self.out_proj = _Linear(embed_dim, embed_dim, has_bias=has_bias)
+ fan_in, _ = _calculate_fan_in_and_fan_out((embed_dim, embed_dim))
+ bound = 1 / math.sqrt(fan_in)
+ self.out_proj = Dense(embed_dim, embed_dim, has_bias=has_bias, weight_init=HeUniform(math.sqrt(5)),
+ bias_init=Uniform(bound), dtype=dtype)

  if add_bias_kv:
- self.bias_k = Parameter(initializer(XavierNormal(), (1, 1, embed_dim)), 'bias_k')
- self.bias_v = Parameter(initializer(XavierNormal(), (1, 1, embed_dim)), 'bias_v')
+ self.bias_k = Parameter(initializer(XavierNormal(), (1, 1, embed_dim), dtype), 'bias_k')
+ self.bias_v = Parameter(initializer(XavierNormal(), (1, 1, embed_dim), dtype), 'bias_v')
  else:
  self.bias_k = self.bias_v = None

  self.add_zero_attn = add_zero_attn
  self.k_is_v = False
  self.q_is_k = False
+ self.dtype = dtype

  def __call__(self, *args, **kwargs):
  query = kwargs.get('query', args[0])
@@ -215,7 +222,7 @@ class MultiheadAttention(Cell):
  attn_mask=attn_mask, use_separate_proj_weight=True,
  q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
  v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights,
- k_is_v=self.k_is_v, q_is_k=self.q_is_k)
+ k_is_v=self.k_is_v, q_is_k=self.q_is_k, dtype=self.dtype)
  else:
  attn_output, attn_output_weights = multi_head_attention_forward(
  query, key, value, self.embed_dim, self.num_heads,
@@ -225,7 +232,7 @@ class MultiheadAttention(Cell):
  training=self.training,
  key_padding_mask=key_padding_mask,
  attn_mask=attn_mask, average_attn_weights=average_attn_weights,
- k_is_v=self.k_is_v, q_is_k=self.q_is_k)
+ k_is_v=self.k_is_v, q_is_k=self.q_is_k, dtype=self.dtype)

  if self.batch_first and is_batched:
  attn_output = attn_output.swapaxes(1, 0)
@@ -245,65 +252,90 @@ class TransformerEncoderLayer(Cell):
  dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
  dropout (float): The dropout value. Default: ``0.1``.
  activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
- can be a string (`"relu"` or `"gelu"`), Cell instance (`nn.ReLU()` or `nn.GELU()`) or
- a callable (`ops.relu` or `ops.gelu`). Default: ``"relu"``.
+ can be a string (``"relu"`` or ``"gelu"``), Cell instance (:class:`mindspore.nn.ReLU` or
+ :class:`mindspore.nn.GELU` ) or a callable ( :func:`mindspore.ops.relu` or
+ :func:`mindspore.ops.gelu` ). Default: ``"relu"``.
  layer_norm_eps (float): The epsilon value in LayerNorm modules. Default: ``1e-5``.
- batch_first (bool): If `batch_first = True`, then the shape of input and output tensors is
- :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)` .
+ batch_first (bool): If `batch_first=True` , then the shape of input and output tensors is
+ :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)` .
  Default: ``False``.
- norm_first (bool): If `norm_first = True`, layer norm is done prior to attention and feedforward
- operations, respectively. Default: ``False``.
+ norm_first (bool): If `norm_first = True`, layer norm is located prior to the attention and feedforward
+ operations; if `norm_first = False`, layer norm is located after the attention and feedforward
+ operations. Default: ``False``.
+ dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .

  Inputs:
- - **src** (Tensor): the sequence to the encoder layer.
- - **src_mask** (Tensor, optional): the mask for the src sequence. Default: ``None``.
- - **src_key_padding_mask** (Tensor, optional): the mask for the src keys per batch.
- Default: ``None``.
+ - **src** (Tensor): the sequence to the encoder layer. For unbatched input, the shape is
+ :math:`(S, E)` ; otherwise if `batch_first=False` , the shape is :math:`(S, N, E)` and if
+ `batch_first=True` , the shape is :math:`(N, S, E)`, where :math:`(S)` is the source sequence
+ length, :math:`(N)` is the batch number and :math:`(E)` is the feature number.
+ Supported types: float16, float32, float64.
+ - **src_mask** (Tensor, optional): the mask for the src sequence. The shape is :math:`(S, S)`
+ or :math:`(N*nhead, S, S)`. Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **src_key_padding_mask** (Tensor, optional): the mask for the src keys per batch. The shape is
+ :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape and dtype of the output Tensor are the same as `src` .
+
+ Raises:
+ ValueError: If the init argument `activation` is not str, callable or Cell instance.
+ ValueError: If the init argument `activation` is not :class:`mindspore.nn.ReLU`,
+ :class:`mindspore.nn.GELU` instance, :func:`mindspore.ops.relu`,
+ :func:`mindspore.ops.gelu`, "relu" or "gelu" .

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
- >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
- >>> src = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
+ >>> import mindspore as ms
+ >>> import numpy as np
+ >>> encoder_layer = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8)
+ >>> src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
  >>> out = encoder_layer(src)
+ >>> print(out.shape)
+ (10, 32, 512)
  >>> # Alternatively, when batch_first=True:
- >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
- >>> src = Tensor(np.random.rand(32, 10, 512), mindspore.float32)
+ >>> encoder_layer = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
+ >>> src = ms.Tensor(np.random.rand(32, 10, 512), ms.float32)
  >>> out = encoder_layer(src)
  >>> print(out.shape)
  (32, 10, 512)
  """
- __constants__ = ['batch_first', 'norm_first']

  def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
  activation: Union[str, Cell, callable] = 'relu', layer_norm_eps: float = 1e-5,
- batch_first: bool = False, norm_first: bool = False):
+ batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32):
  super().__init__()
- self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, dtype=dtype)
  # feedforward layer
- self.linear1 = _Linear(d_model, dim_feedforward)
+ fan_in, _ = _calculate_fan_in_and_fan_out((dim_feedforward, d_model))
+ bound = 1 / math.sqrt(fan_in)
+ self.dense1 = Dense(d_model, dim_feedforward, weight_init=HeUniform(math.sqrt(5)),
+ bias_init=Uniform(bound), dtype=dtype)
  self.dropout = Dropout(p=dropout)
- self.linear2 = _Linear(dim_feedforward, d_model)
+ fan_in1, _ = _calculate_fan_in_and_fan_out((d_model, dim_feedforward))
+ bound1 = 1 / math.sqrt(fan_in1)
+ self.dense2 = Dense(dim_feedforward, d_model, weight_init=HeUniform(math.sqrt(5)),
+ bias_init=Uniform(bound1), dtype=dtype)

  self.norm_first = norm_first
- self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps)
- self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps)
+ self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
+ self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
  self.dropout1 = Dropout(p=dropout)
  self.dropout2 = Dropout(p=dropout)
+ self.activation1 = activation

  if not isinstance(activation, str) and not isinstance(activation, Cell) \
  and not callable(activation):
  raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
  f" but get {activation}.")
- if isinstance(activation, Cell) and (not isinstance(activation, ReLU) or \
+ if isinstance(activation, Cell) and (not isinstance(activation, ReLU) and \
  not isinstance(activation, GELU)):
  raise ValueError(f"The argument 'activation' must be nn.ReLU or nn.GELU instance,"
  f" but get {activation}.")
- if callable(activation) and (activation is not ops.relu or \
+ if callable(activation) and (activation is not ops.relu and \
  activation is not ops.gelu):
  raise ValueError(f"The argument 'activation' must be ops.relu or ops.gelu instance,"
  f" but get {activation}.")
@@ -311,6 +343,14 @@ class TransformerEncoderLayer(Cell):
  if isinstance(activation, str):
  activation = _get_activation_fn(activation)
  self.activation = activation
+ self.d_model = d_model
+ self.nhead = nhead
+ self.dim_feedforward = dim_feedforward
+ self.dropout_num = dropout
+ self.layernorm_eps = layer_norm_eps
+ self.batch_first = batch_first
+ self.norm_first = norm_first
+ self.dtype = dtype

  def construct(self, src: Tensor, src_mask: Optional[Tensor] = None,
  src_key_padding_mask: Optional[Tensor] = None):
@@ -338,7 +378,7 @@ class TransformerEncoderLayer(Cell):
  return self.dropout1(x)

  def _ff_block(self, x):
- x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ x = self.dense2(self.dropout(self.activation(self.dense1(x))))
  return self.dropout2(x)
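The `_ff_block` above is the standard position-wise feedforward network of the Transformer. A NumPy restatement of what it computes, with dropout omitted and `w1/b1/w2/b2` as hypothetical stand-ins for the two Dense layers:

```python
import numpy as np

def ff_block(x, w1, b1, w2, b2):
    # FFN(x) = relu(x W1 + b1) W2 + b2, applied independently at every position.
    hidden = np.maximum(x @ w1 + b1, 0.0)  # dense1 + relu activation
    return hidden @ w2 + b2                # dense2
```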


@@ -353,74 +393,101 @@ class TransformerDecoderLayer(Cell):
  dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
  dropout (float): The dropout value. Default: ``0.1``.
  activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
- can be a string (`"relu"` or `"gelu"`), Cell instance (`nn.ReLU()` or `nn.GELU()`) or
- a callable (`ops.relu` or `ops.gelu`). Default: ``"relu"``
+ can be a string (``"relu"`` or ``"gelu"``), Cell instance (:class:`mindspore.nn.ReLU` or
+ :class:`mindspore.nn.GELU` ) or a callable ( :func:`mindspore.ops.relu` or
+ :func:`mindspore.ops.gelu` ). Default: ``"relu"``.
  layer_norm_eps (float): The epsilon value in LayerNorm modules. Default: ``1e-5``.
- batch_first (bool): If `batch_first = True`, then the shape of input and output tensors is
+ batch_first (bool): If `batch_first=True` , then the shape of input and output tensors is
  :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)`.
  Default: ``False``.
- norm_first (bool): If `norm_first = True`, layer norm is done prior to attention and feedforward
- operations, respectively. Default: ``False``.
+ norm_first (bool): If `norm_first = True`, layer norm is located prior to the attention and feedforward
+ operations; if `norm_first = False`, layer norm is located after the attention and feedforward
+ operations. Default: ``False``.
+ dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .

  Inputs:
- - **tgt** (Tensor): The sequence to the decoder layer.
- - **memory** (Tensor): The sequence from the last layer of the encoder.
- - **tgt_mask** (Tensor, optional): The mask of the tgt sequence. Default: ``None``.
- - **memory_mask** (Tensor, optional): The mask of the memory sequence. Default: ``None``.
- - **tgt_key_padding_mask** (Tensor, optional): The mask of the tgt keys per batch.
- Default: ``None``.
- - **memory_key_padding_mask** (Tensor, optional): The mask of the memory keys per batch.
- Default: ``None``.
+ - **tgt** (Tensor): The sequence to the decoder layer. For unbatched input, the shape is
+ :math:`(T, E)` ; otherwise if `batch_first=False` , the shape is :math:`(T, N, E)` and if
+ `batch_first=True` , the shape is :math:`(N, T, E)`, where :math:`(T)` is the target sequence
+ length. Supported types: float16, float32, float64.
+ - **memory** (Tensor): The sequence from the last layer of the encoder. Supported types: float16,
+ float32, float64.
+ - **tgt_mask** (Tensor, optional): The mask of the tgt sequence. The shape is :math:`(T, T)`
+ or :math:`(N*nhead, T, T)`. Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **memory_mask** (Tensor, optional): The mask of the memory sequence. The shape is
+ :math:`(T, S)` . Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **tgt_key_padding_mask** (Tensor, optional): The mask of the tgt keys per batch. The shape is
+ :math:`(T)` for unbatched input, otherwise :math:`(N, T)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.
+ - **memory_key_padding_mask** (Tensor, optional): The mask of the memory keys per batch. The shape
+ is :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape and dtype of the output Tensor are the same as `tgt` .
+
+ Raises:
+ ValueError: If the init argument `activation` is not str, callable or Cell instance.
+ ValueError: If the init argument `activation` is not :class:`mindspore.nn.ReLU`,
+ :class:`mindspore.nn.GELU` instance, :func:`mindspore.ops.relu`,
+ :func:`mindspore.ops.gelu` , "relu" or "gelu" .

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
- >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
- >>> memory = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
- >>> tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32)
+ >>> import mindspore as ms
+ >>> import numpy as np
+ >>> decoder_layer = ms.nn.TransformerDecoderLayer(d_model=512, nhead=8)
+ >>> memory = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
+ >>> tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
  >>> out = decoder_layer(tgt, memory)
+ >>> print(out.shape)
+ (20, 32, 512)
  >>> # Alternatively, when `batch_first` is ``True``:
- >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
- >>> memory = Tensor(np.random.rand(32, 10, 512), mindspore.float32)
- >>> tgt = Tensor(np.random.rand(32, 20, 512), mindspore.float32)
+ >>> decoder_layer = ms.nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
+ >>> memory = ms.Tensor(np.random.rand(32, 10, 512), ms.float32)
+ >>> tgt = ms.Tensor(np.random.rand(32, 20, 512), ms.float32)
  >>> out = decoder_layer(tgt, memory)
  >>> print(out.shape)
  (32, 20, 512)
  """
- __constants__ = ['batch_first', 'norm_first']

  def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
  activation: Union[str, Cell, callable] = 'relu', layer_norm_eps: float = 1e-5,
- batch_first: bool = False, norm_first: bool = False):
+ batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32):
  super().__init__()
- self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
- self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, dtype=dtype)
+ self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, dtype=dtype)
  # feedforward layer
- self.linear1 = _Linear(d_model, dim_feedforward)
+ fan_in, _ = _calculate_fan_in_and_fan_out((dim_feedforward, d_model))
+ bound = 1 / math.sqrt(fan_in)
+ self.dense1 = Dense(d_model, dim_feedforward, weight_init=HeUniform(math.sqrt(5)),
+ bias_init=Uniform(bound), dtype=dtype)
  self.dropout = Dropout(p=dropout)
- self.linear2 = _Linear(dim_feedforward, d_model)
+ fan_in1, _ = _calculate_fan_in_and_fan_out((d_model, dim_feedforward))
+ bound1 = 1 / math.sqrt(fan_in1)
+ self.dense2 = Dense(dim_feedforward, d_model, weight_init=HeUniform(math.sqrt(5)),
+ bias_init=Uniform(bound1), dtype=dtype)

  self.norm_first = norm_first
- self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps)
- self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps)
- self.norm3 = LayerNorm((d_model,), epsilon=layer_norm_eps)
+ self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
+ self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
+ self.norm3 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
  self.dropout1 = Dropout(p=dropout)
  self.dropout2 = Dropout(p=dropout)
  self.dropout3 = Dropout(p=dropout)
+ self.activation1 = activation

  if not isinstance(activation, str) and not isinstance(activation, Cell) \
  and not callable(activation):
  raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
  f" but get {activation}.")
- if isinstance(activation, Cell) and (not isinstance(activation, ReLU) or \
+ if isinstance(activation, Cell) and (not isinstance(activation, ReLU) and \
  not isinstance(activation, GELU)):
  raise ValueError(f"The argument 'activation' must be nn.ReLU or nn.GELU instance,"
  f" but get {activation}.")
- if callable(activation) and (activation is not ops.relu or \
+ if callable(activation) and (activation is not ops.relu and \
  activation is not ops.gelu):
  raise ValueError(f"The argument 'activation' must be ops.relu or ops.gelu instance,"
  f" but get {activation}.")
@@ -428,6 +495,14 @@ class TransformerDecoderLayer(Cell):
  if isinstance(activation, str):
  activation = _get_activation_fn(activation)
  self.activation = activation
+ self.d_model = d_model
+ self.nhead = nhead
+ self.dim_feedforward = dim_feedforward
+ self.dropout_num = dropout
+ self.layernorm_eps = layer_norm_eps
+ self.batch_first = batch_first
+ self.norm_first = norm_first
+ self.dtype = dtype

  def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
  memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
@@ -459,46 +534,61 @@ class TransformerDecoderLayer(Cell):
  return self.dropout2(x)

  def _ff_block(self, x):
- x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ x = self.dense2(self.dropout(self.activation(self.dense1(x))))
  return self.dropout3(x)


  class TransformerEncoder(Cell):
  r"""
- Transformer Encoder module with multi-layer stacked of `TransformerEncoderLayer`, including multihead self
+ Transformer Encoder module with multiple stacked `TransformerEncoderLayer` layers, including multihead
  attention and feedforward layer. Users can build the
  BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.

  Args:
- encoder_layer (Cell): An instance of the TransformerEncoderLayer() class.
+ encoder_layer (Cell): An instance of the :class:`mindspore.nn.TransformerEncoderLayer` class.
  num_layers (int): The number of encoder-layers in the encoder.
- norm (Cell, optional): The layer normalization module.
+ norm (Cell, optional): The layer normalization module. Default: ``None``.

  Inputs:
- - **src** (Tensor): The sequence to the encoder.
- - **src_mask** (Tensor, optional): The mask of the src sequence. Default: ``None``.
- - **src_key_padding_mask** (Tensor, optional): the mask of the src keys per batch .
- Default: ``None``.
+ - **src** (Tensor): The sequence to the encoder. For unbatched input, the shape is
+ :math:`(S, E)` ; otherwise if `batch_first=False` in TransformerEncoderLayer, the shape is
+ :math:`(S, N, E)` and if `batch_first=True` , the shape is :math:`(N, S, E)`, where :math:`(S)` is the
+ source sequence length, :math:`(N)` is the batch number and :math:`(E)` is the feature number.
+ Supported types: float16, float32, float64.
+ - **src_mask** (Tensor, optional): The mask of the src sequence. The shape is :math:`(S, S)`
+ or :math:`(N*nhead, S, S)` , where `nhead` is the argument in TransformerEncoderLayer.
+ Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **src_key_padding_mask** (Tensor, optional): the mask of the src keys per batch. The shape is
+ :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape and dtype of the output Tensor are the same as `src` .
+
+ Raises:
+ AssertionError: If the input argument `src_key_padding_mask` is not of bool or floating-point type.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
- >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
- >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
- >>> src = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
+ >>> import mindspore as ms
+ >>> import numpy as np
+ >>> encoder_layer = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8)
+ >>> transformer_encoder = ms.nn.TransformerEncoder(encoder_layer, num_layers=6)
+ >>> src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
  >>> out = transformer_encoder(src)
  >>> print(out.shape)
  (10, 32, 512)
  """
- __constants__ = ['norm']

  def __init__(self, encoder_layer, num_layers, norm=None):
  super(TransformerEncoder, self).__init__()
- self.layers = _get_clones(encoder_layer, num_layers)
+ layers = TransformerEncoderLayer(encoder_layer.d_model, encoder_layer.nhead, encoder_layer.dim_feedforward,
+ encoder_layer.dropout_num, encoder_layer.activation1,
+ encoder_layer.layernorm_eps, encoder_layer.batch_first,
+ encoder_layer.norm_first, dtype=encoder_layer.dtype)
+ self.layers = CellList([layers for _ in range(num_layers)])
  self.num_layers = num_layers
  self.norm = norm
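Note the construction above: one `TransformerEncoderLayer` is rebuilt from the attributes of `encoder_layer` and that single instance is repeated in the `CellList`, so all `num_layers` entries share the same parameters; the removed `_get_clones` helper (visible at the end of this diff) deep-copied the prototype instead. `TransformerDecoder` below uses the same pattern. If independent per-layer weights are wanted, a deepcopy-based clone along the lines of the removed helper would look like this (a sketch, not part of the package):

```python
import copy
from mindspore.nn import CellList

def clone_layers(prototype, num_layers):
    # Deep-copy the prototype so every stacked layer owns its own parameters.
    return CellList([copy.deepcopy(prototype) for _ in range(num_layers)])
```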
@@ -527,38 +617,51 @@ class TransformerDecoder(Cell):
  Args:
  decoder_layer (Cell): An instance of the :class:`mindspore.nn.TransformerDecoderLayer` class.
  num_layers (int): The number of decoder-layers in the decoder.
- norm (Cell, optional): The layer normalization module.
+ norm (Cell, optional): The layer normalization module. Default: ``None``.

  Inputs:
- - **tgt** (Tensor): The sequence to the decoder.
- - **memory** (Tensor): The sequence from the last layer of the encoder.
- - **tgt_mask** (Tensor, optional): the mask of the tgt sequence. Default: ``None``.
- - **memory_mask** (Tensor, optional): the mask of the memory sequence. Default: ``None``.
- - **tgt_key_padding_mask** (Tensor, optional): the mask of the tgt keys per batch.
- Default: ``None``.
- - **memory_key_padding_mask** (Tensor, optional): the mask of the memory keys per batch.
- Default: ``None``.
+ - **tgt** (Tensor): The sequence to the decoder. For unbatched input, the shape is
+ :math:`(T, E)` ; otherwise if `batch_first=False` in TransformerDecoderLayer, the shape is
+ :math:`(T, N, E)` and if `batch_first=True` , the shape is :math:`(N, T, E)`, where :math:`(T)` is the
+ target sequence length. Supported types: float16, float32, float64.
+ - **memory** (Tensor): The sequence from the last layer of the encoder. Supported types: float16,
+ float32, float64.
+ - **tgt_mask** (Tensor, optional): the mask of the tgt sequence. The shape is :math:`(T, T)`
+ or :math:`(N*nhead, T, T)` , where `nhead` is the argument in TransformerDecoderLayer.
+ Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **memory_mask** (Tensor, optional): the mask of the memory sequence. The shape is
+ :math:`(T, S)` . Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **tgt_key_padding_mask** (Tensor, optional): the mask of the tgt keys per batch. Supported
+ types: float16, float32, float64, bool. Default: ``None``.
+ - **memory_key_padding_mask** (Tensor, optional): the mask of the memory keys per batch. The shape
+ is :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape and dtype of the output Tensor are the same as `tgt` .

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
- >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
- >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
- >>> memory = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
- >>> tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32)
+ >>> import mindspore as ms
+ >>> import numpy as np
+ >>> decoder_layer = ms.nn.TransformerDecoderLayer(d_model=512, nhead=8)
+ >>> transformer_decoder = ms.nn.TransformerDecoder(decoder_layer, num_layers=6)
+ >>> memory = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
+ >>> tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
  >>> out = transformer_decoder(tgt, memory)
  >>> print(out.shape)
  (20, 32, 512)
  """
- __constants__ = ['norm']

  def __init__(self, decoder_layer, num_layers, norm=None):
  super(TransformerDecoder, self).__init__()
- self.layers = _get_clones(decoder_layer, num_layers)
+ layers = TransformerDecoderLayer(decoder_layer.d_model, decoder_layer.nhead, decoder_layer.dim_feedforward,
+ decoder_layer.dropout_num, decoder_layer.activation1,
+ decoder_layer.layernorm_eps, decoder_layer.batch_first,
+ decoder_layer.norm_first, dtype=decoder_layer.dtype)
+ self.layers = CellList([layers for _ in range(num_layers)])
  self.num_layers = num_layers
  self.norm = norm

@@ -566,7 +669,6 @@ class TransformerDecoder(Cell):
  memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
  memory_key_padding_mask: Optional[Tensor] = None):
  output = tgt
-
  for mod in self.layers:
  output = mod(output, memory, tgt_mask=tgt_mask,
  memory_mask=memory_mask,
@@ -582,52 +684,74 @@ class TransformerDecoder(Cell):
  class Transformer(Cell):
  r"""
  Transformer module including encoder and decoder. The difference from the original implementation is that the module uses
- the residual addition before the layer normalization. And the default hidden act is `gelu`.
+ the residual addition before the layer normalization. And the default hidden activation is `gelu`.
  The details can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.

  Args:
- d_model (int): The number of expected features in the inputs tensor. Default: ``512``.
+ d_model (int): The number of expected features in the inputs tensor for Encoder and Decoder. Default: ``512``.
  nhead (int): The number of heads in the MultiheadAttention modules. Default: ``8``.
  num_encoder_layers (int): The number of encoder-layers in the encoder. Default: ``6``.
  num_decoder_layers (int): The number of decoder-layers in the decoder. Default: ``6``.
  dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
  dropout (float): The dropout value. Default: ``0.1``.
  activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
- can be a string (`"relu"` or `"gelu"`), Cell instance (`nn.ReLU()` or `nn.GELU()`) or
- a callable (`ops.relu` or `ops.gelu`). Default: ``"relu"``
+ can be a string (``"relu"`` or ``"gelu"``), Cell instance (:class:`mindspore.nn.ReLU` or
+ :class:`mindspore.nn.GELU` ) or a callable ( :func:`mindspore.ops.relu` or
+ :func:`mindspore.ops.gelu` ). Default: ``"relu"``.
  custom_encoder (Cell): Custom encoder. Default: ``None``.
  custom_decoder (Cell): Custom decoder. Default: ``None``.
  layer_norm_eps (float): the epsilon value in layer normalization module. Default: ``1e-5``.
- batch_first (bool): If `batch_first = True`, then the shape of input and output tensors is
+ batch_first (bool): If `batch_first=True`, then the shape of input and output tensors is
  :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)` .
  Default: ``False``.
- norm_first (bool): If `norm_first = True`, layer norm is done prior to attention and feedforward
- operations, respectively. Default: ``False``.
+ norm_first (bool): If `norm_first = True`, layer norm is located prior to the attention and feedforward
+ operations; if `norm_first = False`, layer norm is located after the attention and feedforward
+ operations. Default: ``False``.
+ dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .

  Inputs:
- - **src** (Tensor): The source sequence to the encoder.
- - **tgt** (Tensor): The target sequence to the decoder.
- - **src_mask** (Tensor, optional): The mask of the src sequence. Default: ``None``.
- - **tgt_mask** (Tensor, optional): The mask of the tgt sequence. Default: ``None``.
- - **memory_mask** (Tensor, optional): The additive mask of the encoder output.
- Default: ``None``.
- - **src_key_padding_mask** (Tensor, optional): The mask of src keys per batch.
- Default: ``None``.
- - **tgt_key_padding_mask** (Tensor, optional): The mask of tgt keys per batch.
- Default: ``None``.
- - **memory_key_padding_mask** (Tensor, optional): The mask of memory keys per batch.
- Default: ``None``.
+ - **src** (Tensor): The source sequence to the encoder. For unbatched input, the shape is
+ :math:`(S, E)` ; otherwise if `batch_first=False` , the shape is :math:`(S, N, E)` and if
+ `batch_first=True` , the shape is :math:`(N, S, E)`, where :math:`(S)` is the source sequence
+ length, :math:`(N)` is the batch number and :math:`(E)` is the feature number. Supported
+ types: float16, float32, float64.
+ - **tgt** (Tensor): The target sequence to the decoder. For unbatched input, the shape is
+ :math:`(T, E)` ; otherwise if `batch_first=False` , the shape is :math:`(T, N, E)` and if
+ `batch_first=True` , the shape is :math:`(N, T, E)`, where :math:`(T)` is the target sequence
+ length. Supported types: float16, float32, float64.
+ - **src_mask** (Tensor, optional): The mask of the src sequence. The shape is :math:`(S, S)`
+ or :math:`(N*nhead, S, S)`. Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **tgt_mask** (Tensor, optional): The mask of the tgt sequence. The shape is :math:`(T, T)`
+ or :math:`(N*nhead, T, T)`. Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **memory_mask** (Tensor, optional): The additive mask of the encoder output. The shape is
+ :math:`(T, S)` . Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **src_key_padding_mask** (Tensor, optional): The mask of src keys per batch. The shape is
+ :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.
+ - **tgt_key_padding_mask** (Tensor, optional): The mask of tgt keys per batch. The shape is
+ :math:`(T)` for unbatched input, otherwise :math:`(N, T)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.
+ - **memory_key_padding_mask** (Tensor, optional): The mask of memory keys per batch. The shape
+ is :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16,
+ float32, float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape is :math:`(T, E)` for unbatched input, otherwise if `batch_first=False` , the shape is
+ :math:`(T, N, E)` and if `batch_first=True` , the shape is :math:`(N, T, E)`.
+
+ Raises:
+ ValueError: If the batch sizes of the input arguments `src` and `tgt` are not equal.
+ ValueError: If the feature number of the input arguments `src` and `tgt` is not equal to `d_model`.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
- >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
- >>> src = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
- >>> tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32)
+ >>> import mindspore as ms
+ >>> import numpy as np
+ >>> transformer_model = ms.nn.Transformer(nhead=16, num_encoder_layers=12)
+ >>> src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
+ >>> tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
  >>> out = transformer_model(src, tgt)
  >>> print(out.shape)
  (20, 32, 512)
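The two `ValueError` conditions listed under Raises can be illustrated with a small shape check. This is a sketch of the validation the docstring describes, not the verbatim source; `check_src_tgt` and its arguments are hypothetical:

```python
def check_src_tgt(src_shape, tgt_shape, d_model, batch_first=False):
    # Batched layouts: (S, N, E) / (T, N, E), or (N, S, E) / (N, T, E) if batch_first.
    batch_axis = 0 if batch_first else 1
    if src_shape[batch_axis] != tgt_shape[batch_axis]:
        raise ValueError("the batch sizes of 'src' and 'tgt' must be equal")
    if src_shape[-1] != d_model or tgt_shape[-1] != d_model:
        raise ValueError("the feature number of 'src' and 'tgt' must equal 'd_model'")

check_src_tgt((10, 32, 512), (20, 32, 512), 512)  # passes: same batch, same features
```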
@@ -637,23 +761,23 @@ class Transformer(Cell):
  num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
  activation: Union[str, Cell, callable] = 'relu', custom_encoder: Optional[Cell] = None,
  custom_decoder: Optional[Cell] = None, layer_norm_eps: float = 1e-5,
- batch_first: bool = False, norm_first: bool = False):
+ batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32):
  super(Transformer, self).__init__()

  if custom_encoder is not None:
  self.encoder = custom_encoder
  else:
  encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
- activation, layer_norm_eps, batch_first, norm_first)
- encoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps)
+ activation, layer_norm_eps, batch_first, norm_first, dtype=dtype)
+ encoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
  self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

  if custom_decoder is not None:
  self.decoder = custom_decoder
  else:
  decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
- activation, layer_norm_eps, batch_first, norm_first)
- decoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps)
+ activation, layer_norm_eps, batch_first, norm_first, dtype=dtype)
+ decoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
  self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

  for _, p in self.parameters_and_names():
@@ -695,7 +819,3 @@ def _get_activation_fn(activation: str):
  return ops.gelu

  raise ValueError(f"The activation must be relu/gelu, but get {activation}")
-
-
- def _get_clones(module, N):
- return CellList([copy.deepcopy(module) for i in range(N)])