mindspore 1.10.0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the package's registry page for more details.

Files changed (966):
  1. mindspore/.commit_id +1 -1
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +9 -4
  18. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +102 -0
  22. mindspore/_checkparam.py +1066 -1001
  23. mindspore/_extends/builtin_operations.py +32 -4
  24. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  25. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  26. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  27. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  28. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  29. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  30. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  31. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  32. mindspore/_extends/parse/__init__.py +5 -3
  33. mindspore/_extends/parse/namespace.py +17 -2
  34. mindspore/_extends/parse/parser.py +193 -34
  35. mindspore/_extends/parse/resources.py +7 -8
  36. mindspore/_extends/parse/standard_method.py +1780 -435
  37. mindspore/_extends/parse/trope.py +3 -1
  38. mindspore/amp.py +53 -58
  39. mindspore/atlprov.dll +0 -0
  40. mindspore/boost/adasum.py +3 -2
  41. mindspore/boost/boost.py +2 -2
  42. mindspore/boost/boost_cell_wrapper.py +46 -26
  43. mindspore/boost/dim_reduce.py +6 -5
  44. mindspore/boost/grad_accumulation.py +2 -1
  45. mindspore/boost/group_loss_scale_manager.py +1 -1
  46. mindspore/c1.dll +0 -0
  47. mindspore/c1xx.dll +0 -0
  48. mindspore/c2.dll +0 -0
  49. mindspore/cfgpersist.dll +0 -0
  50. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  51. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  52. mindspore/common/__init__.py +11 -10
  53. mindspore/common/_decorator.py +2 -0
  54. mindspore/common/_register_for_adapter.py +55 -0
  55. mindspore/common/_stub_tensor.py +201 -0
  56. mindspore/common/_utils.py +57 -0
  57. mindspore/common/api.py +582 -297
  58. mindspore/common/dtype.py +66 -18
  59. mindspore/common/dump.py +2 -2
  60. mindspore/common/initializer.py +38 -1
  61. mindspore/common/jit_config.py +25 -13
  62. mindspore/common/mutable.py +53 -24
  63. mindspore/common/parameter.py +60 -37
  64. mindspore/common/seed.py +8 -24
  65. mindspore/common/sparse_tensor.py +927 -0
  66. mindspore/common/tensor.py +1627 -3900
  67. mindspore/communication/__init__.py +10 -5
  68. mindspore/communication/_comm_helper.py +78 -214
  69. mindspore/communication/_hccl_management.py +2 -1
  70. mindspore/communication/management.py +136 -47
  71. mindspore/config/op_info.config +501 -1008
  72. mindspore/context.py +291 -56
  73. mindspore/d3dcompiler_47.dll +0 -0
  74. mindspore/dataset/__init__.py +12 -8
  75. mindspore/dataset/audio/__init__.py +9 -9
  76. mindspore/dataset/audio/transforms.py +1090 -228
  77. mindspore/dataset/audio/utils.py +87 -39
  78. mindspore/dataset/audio/validators.py +223 -1
  79. mindspore/dataset/callback/ds_callback.py +17 -15
  80. mindspore/dataset/core/config.py +246 -17
  81. mindspore/dataset/core/py_util_helpers.py +4 -3
  82. mindspore/dataset/core/validator_helpers.py +10 -10
  83. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  84. mindspore/dataset/debug/debug_hook.py +65 -0
  85. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  86. mindspore/dataset/engine/__init__.py +7 -3
  87. mindspore/dataset/engine/cache_client.py +9 -9
  88. mindspore/dataset/engine/datasets.py +648 -477
  89. mindspore/dataset/engine/datasets_audio.py +165 -167
  90. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  91. mindspore/dataset/engine/datasets_text.py +492 -342
  92. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  93. mindspore/dataset/engine/datasets_vision.py +1224 -699
  94. mindspore/dataset/engine/graphdata.py +134 -69
  95. mindspore/dataset/engine/iterators.py +50 -9
  96. mindspore/dataset/engine/offload.py +52 -31
  97. mindspore/dataset/engine/samplers.py +27 -24
  98. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  99. mindspore/dataset/engine/validators.py +213 -52
  100. mindspore/dataset/text/__init__.py +10 -8
  101. mindspore/dataset/text/transforms.py +152 -57
  102. mindspore/dataset/text/utils.py +98 -49
  103. mindspore/dataset/text/validators.py +25 -0
  104. mindspore/dataset/transforms/__init__.py +4 -2
  105. mindspore/dataset/transforms/c_transforms.py +11 -13
  106. mindspore/dataset/transforms/py_transforms.py +2 -2
  107. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  108. mindspore/dataset/transforms/transforms.py +13 -15
  109. mindspore/dataset/transforms/validators.py +7 -7
  110. mindspore/dataset/utils/__init__.py +2 -1
  111. mindspore/dataset/utils/browse_dataset.py +13 -13
  112. mindspore/dataset/utils/line_reader.py +121 -0
  113. mindspore/dataset/vision/__init__.py +8 -7
  114. mindspore/dataset/vision/c_transforms.py +125 -126
  115. mindspore/dataset/vision/py_transforms.py +37 -37
  116. mindspore/dataset/vision/py_transforms_util.py +23 -20
  117. mindspore/dataset/vision/transforms.py +316 -315
  118. mindspore/dataset/vision/utils.py +313 -17
  119. mindspore/dataset/vision/validators.py +6 -6
  120. mindspore/default_config.py +0 -1
  121. mindspore/dpcmi.dll +0 -0
  122. mindspore/{compression → experimental}/__init__.py +6 -5
  123. mindspore/experimental/map_parameter.py +275 -0
  124. mindspore/include/OWNERS +0 -1
  125. mindspore/include/api/callback/callback.h +9 -13
  126. mindspore/include/api/callback/ckpt_saver.h +2 -2
  127. mindspore/include/api/callback/loss_monitor.h +2 -2
  128. mindspore/include/api/callback/lr_scheduler.h +5 -5
  129. mindspore/include/api/callback/time_monitor.h +2 -2
  130. mindspore/include/api/callback/train_accuracy.h +4 -6
  131. mindspore/include/api/cfg.h +19 -6
  132. mindspore/include/api/context.h +70 -9
  133. mindspore/include/api/delegate.h +8 -1
  134. mindspore/include/api/dual_abi_helper.h +8 -24
  135. mindspore/include/api/metrics/accuracy.h +2 -2
  136. mindspore/include/api/metrics/metrics.h +4 -3
  137. mindspore/include/api/model.h +9 -4
  138. mindspore/include/api/model_group.h +68 -0
  139. mindspore/include/api/model_parallel_runner.h +17 -17
  140. mindspore/include/api/net.h +12 -11
  141. mindspore/include/api/serialization.h +20 -4
  142. mindspore/include/api/status.h +7 -1
  143. mindspore/include/api/types.h +25 -21
  144. mindspore/include/api/visible.h +4 -0
  145. mindspore/include/c_api/model_c.h +5 -0
  146. mindspore/include/c_api/status_c.h +1 -1
  147. mindspore/include/dataset/config.h +1 -1
  148. mindspore/include/dataset/constants.h +14 -0
  149. mindspore/include/dataset/text.h +59 -0
  150. mindspore/include/dataset/vision.h +56 -117
  151. mindspore/include/dataset/vision_lite.h +102 -0
  152. mindspore/jpeg62.dll +0 -0
  153. mindspore/log.py +28 -28
  154. mindspore/mindrecord/common/exceptions.py +2 -4
  155. mindspore/mindrecord/filereader.py +19 -1
  156. mindspore/mindrecord/filewriter.py +250 -88
  157. mindspore/mindrecord/mindpage.py +13 -13
  158. mindspore/mindrecord/shardheader.py +15 -15
  159. mindspore/mindrecord/shardreader.py +9 -0
  160. mindspore/mindrecord/shardwriter.py +29 -29
  161. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  162. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  163. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  164. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  165. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  166. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  167. mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
  168. mindspore/mindspore_common.dll +0 -0
  169. mindspore/mindspore_core.dll +0 -0
  170. mindspore/mindspore_glog.dll +0 -0
  171. mindspore/mindspore_shared_lib.dll +0 -0
  172. mindspore/msobj140.dll +0 -0
  173. mindspore/mspdb140.dll +0 -0
  174. mindspore/mspdbcore.dll +0 -0
  175. mindspore/mspdbst.dll +0 -0
  176. mindspore/mspft140.dll +0 -0
  177. mindspore/msvcdis140.dll +0 -0
  178. mindspore/msvcp140_1.dll +0 -0
  179. mindspore/msvcp140_2.dll +0 -0
  180. mindspore/msvcp140_atomic_wait.dll +0 -0
  181. mindspore/msvcp140_codecvt_ids.dll +0 -0
  182. mindspore/nn/__init__.py +1 -5
  183. mindspore/nn/cell.py +297 -234
  184. mindspore/nn/dynamic_lr.py +1 -1
  185. mindspore/nn/grad/cell_grad.py +17 -42
  186. mindspore/nn/layer/__init__.py +7 -4
  187. mindspore/nn/layer/activation.py +131 -88
  188. mindspore/nn/layer/basic.py +313 -613
  189. mindspore/nn/layer/channel_shuffle.py +103 -0
  190. mindspore/nn/layer/combined.py +1 -1
  191. mindspore/nn/layer/container.py +52 -6
  192. mindspore/nn/layer/conv.py +112 -43
  193. mindspore/nn/layer/dense.py +10 -9
  194. mindspore/nn/layer/embedding.py +36 -34
  195. mindspore/nn/layer/image.py +123 -27
  196. mindspore/nn/layer/math.py +108 -107
  197. mindspore/nn/layer/normalization.py +212 -366
  198. mindspore/nn/layer/padding.py +370 -42
  199. mindspore/nn/layer/pooling.py +1443 -219
  200. mindspore/nn/layer/rnn_cells.py +11 -16
  201. mindspore/nn/layer/rnns.py +38 -39
  202. mindspore/nn/layer/thor_layer.py +24 -25
  203. mindspore/nn/layer/timedistributed.py +5 -5
  204. mindspore/nn/layer/transformer.py +701 -0
  205. mindspore/nn/learning_rate_schedule.py +8 -8
  206. mindspore/nn/loss/__init__.py +9 -6
  207. mindspore/nn/loss/loss.py +678 -142
  208. mindspore/nn/metrics.py +53 -0
  209. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  210. mindspore/nn/optim/ada_grad.py +8 -8
  211. mindspore/nn/optim/adadelta.py +2 -3
  212. mindspore/nn/optim/adafactor.py +18 -14
  213. mindspore/nn/optim/adam.py +429 -87
  214. mindspore/nn/optim/adamax.py +5 -6
  215. mindspore/nn/optim/adasum.py +10 -8
  216. mindspore/nn/optim/asgd.py +7 -7
  217. mindspore/nn/optim/ftrl.py +81 -11
  218. mindspore/nn/optim/lamb.py +7 -8
  219. mindspore/nn/optim/lars.py +4 -4
  220. mindspore/nn/optim/lazyadam.py +82 -7
  221. mindspore/nn/optim/momentum.py +8 -7
  222. mindspore/nn/optim/optimizer.py +19 -10
  223. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  224. mindspore/nn/optim/rmsprop.py +3 -3
  225. mindspore/nn/optim/rprop.py +20 -16
  226. mindspore/nn/optim/sgd.py +21 -15
  227. mindspore/nn/optim/thor.py +23 -21
  228. mindspore/nn/probability/__init__.py +0 -2
  229. mindspore/nn/probability/bijector/bijector.py +7 -6
  230. mindspore/nn/probability/bijector/invert.py +4 -2
  231. mindspore/nn/probability/bijector/softplus.py +2 -2
  232. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  233. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  234. mindspore/nn/probability/distribution/__init__.py +6 -0
  235. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  236. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  237. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  238. mindspore/nn/probability/distribution/beta.py +1 -1
  239. mindspore/nn/probability/distribution/categorical.py +9 -9
  240. mindspore/nn/probability/distribution/cauchy.py +8 -8
  241. mindspore/nn/probability/distribution/distribution.py +12 -6
  242. mindspore/nn/probability/distribution/exponential.py +5 -5
  243. mindspore/nn/probability/distribution/gamma.py +3 -3
  244. mindspore/nn/probability/distribution/geometric.py +6 -5
  245. mindspore/nn/probability/distribution/gumbel.py +5 -5
  246. mindspore/nn/probability/distribution/half_normal.py +133 -0
  247. mindspore/nn/probability/distribution/laplace.py +128 -0
  248. mindspore/nn/probability/distribution/log_normal.py +0 -1
  249. mindspore/nn/probability/distribution/logistic.py +4 -5
  250. mindspore/nn/probability/distribution/normal.py +11 -15
  251. mindspore/nn/probability/distribution/poisson.py +6 -2
  252. mindspore/nn/probability/distribution/student_t.py +150 -0
  253. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  254. mindspore/nn/probability/distribution/uniform.py +5 -5
  255. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  256. mindspore/nn/reinforcement/tensor_array.py +2 -2
  257. mindspore/nn/sparse/sparse.py +8 -1
  258. mindspore/nn/wrap/cell_wrapper.py +55 -27
  259. mindspore/nn/wrap/grad_reducer.py +20 -11
  260. mindspore/nn/wrap/loss_scale.py +47 -30
  261. mindspore/numpy/array_creations.py +33 -22
  262. mindspore/numpy/array_ops.py +46 -42
  263. mindspore/numpy/logic_ops.py +6 -27
  264. mindspore/numpy/math_ops.py +26 -19
  265. mindspore/numpy/utils.py +1 -8
  266. mindspore/numpy/utils_const.py +112 -62
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/opencv_imgproc452.dll +0 -0
  270. mindspore/ops/__init__.py +6 -3
  271. mindspore/ops/_constants.py +0 -6
  272. mindspore/ops/_grad/__init__.py +2 -1
  273. mindspore/ops/_grad/grad_array_ops.py +209 -152
  274. mindspore/ops/_grad/grad_base.py +55 -17
  275. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  276. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  277. mindspore/ops/_grad/grad_implementations.py +21 -61
  278. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  279. mindspore/ops/_grad/grad_math_ops.py +306 -161
  280. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  281. mindspore/ops/_grad/grad_other_ops.py +1 -1
  282. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  283. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  284. mindspore/ops/_grad/grad_sparse.py +15 -9
  285. mindspore/ops/_grad_experimental/__init__.py +1 -0
  286. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  287. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  288. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  289. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  290. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  291. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  292. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  293. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  294. mindspore/ops/_op_impl/__init__.py +3 -3
  295. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  296. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  297. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  298. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  299. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  300. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  302. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  303. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  304. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  305. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  307. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  314. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  315. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  317. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  318. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  320. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  321. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  322. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  326. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  327. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  328. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  329. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  330. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  331. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  332. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  335. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  336. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  337. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  338. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  339. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  340. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  341. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  342. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  343. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  344. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  345. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  346. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  347. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  349. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  350. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  351. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  352. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  353. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  354. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  355. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  356. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  357. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  358. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  359. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  360. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  361. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  362. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  363. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  365. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  366. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  367. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  368. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  369. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  370. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  371. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  372. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  373. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  374. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  375. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  376. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  377. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  378. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  379. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  380. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  381. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  382. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  383. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  384. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  385. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  386. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  387. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  388. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  389. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  390. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  391. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  392. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  393. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  394. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  395. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  396. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  397. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  398. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  399. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  400. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  401. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  402. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  403. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  404. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  405. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  407. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  408. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  409. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  410. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  411. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  412. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  413. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  414. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  415. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  416. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  417. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  418. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  419. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  420. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  421. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  422. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  423. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  424. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  425. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  426. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  427. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  428. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  429. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  430. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  431. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  432. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  433. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  434. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  435. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  436. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  437. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  441. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  442. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  443. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  444. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  445. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  446. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  447. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  448. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  449. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  450. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  451. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  452. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  453. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  454. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  455. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  456. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  457. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  458. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  459. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  460. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  461. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  462. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  463. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  464. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  465. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  466. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  467. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  468. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  469. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  470. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  471. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  472. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  473. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  474. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  475. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  476. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  477. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  478. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  479. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  480. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  481. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  482. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  483. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  484. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  485. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  486. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  487. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  488. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  489. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  490. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  491. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  492. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  493. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  494. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  495. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  496. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  497. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  498. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  499. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  500. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  501. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  502. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  503. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  504. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  505. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  506. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  507. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  508. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  509. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  510. mindspore/ops/_primitive_cache.py +3 -2
  511. mindspore/ops/_register_for_op.py +11 -0
  512. mindspore/ops/_utils/__init__.py +1 -1
  513. mindspore/ops/_utils/utils.py +20 -41
  514. mindspore/ops/_vmap/__init__.py +2 -2
  515. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  516. mindspore/ops/_vmap/vmap_base.py +24 -10
  517. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  518. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  519. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  520. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  521. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  522. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  523. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  524. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  525. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  526. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  529. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  530. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  531. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  532. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  533. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  534. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  535. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  537. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  539. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  541. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  542. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  543. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  546. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  547. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  548. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  549. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  550. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  551. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  552. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  553. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  554. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  555. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  556. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  559. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  560. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  561. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  565. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  566. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  568. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  569. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  570. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  571. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  572. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  573. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  574. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  576. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  577. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  578. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  579. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  580. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  582. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  583. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  584. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  585. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  586. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  587. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  588. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  589. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  590. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  591. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  592. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  593. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  594. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  595. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  596. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  597. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  598. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  599. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  600. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  601. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  604. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  605. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  606. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  607. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  609. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  610. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  611. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  612. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  613. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  614. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  615. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  616. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  617. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  618. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  619. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  620. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  621. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  622. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  623. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  624. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  625. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  626. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  627. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  628. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  629. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  630. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  632. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  633. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  634. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  635. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  637. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  638. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  639. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  640. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  643. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  644. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  645. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  646. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  647. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  648. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  649. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  650. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  651. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  652. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  653. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  654. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  655. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  656. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  657. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  658. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  660. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  661. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  662. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  663. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  664. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  665. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  666. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  667. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  668. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  669. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  670. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  671. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  672. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  673. mindspore/ops/bprop_mindir/__init__.py +1 -4
  674. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  675. mindspore/ops/composite/__init__.py +12 -13
  676. mindspore/ops/composite/base.py +261 -254
  677. mindspore/ops/composite/env_ops.py +41 -0
  678. mindspore/ops/composite/math_ops.py +197 -156
  679. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  680. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  681. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  682. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  683. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  684. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  685. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  686. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  687. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  688. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  689. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  690. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  691. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  692. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  693. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  694. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  695. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  696. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  697. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  698. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  699. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  700. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  701. mindspore/ops/function/__init__.py +323 -8
  702. mindspore/ops/function/array_func.py +3511 -780
  703. mindspore/ops/function/clip_func.py +329 -0
  704. mindspore/ops/function/debug_func.py +6 -6
  705. mindspore/ops/function/grad/__init__.py +5 -1
  706. mindspore/ops/function/grad/grad_func.py +736 -65
  707. mindspore/ops/function/image_func.py +270 -0
  708. mindspore/ops/function/linalg_func.py +268 -8
  709. mindspore/ops/function/math_func.py +8032 -3164
  710. mindspore/ops/function/nn_func.py +5619 -1855
  711. mindspore/ops/function/other_func.py +115 -0
  712. mindspore/ops/function/parameter_func.py +11 -10
  713. mindspore/ops/function/random_func.py +939 -77
  714. mindspore/ops/function/sparse_func.py +249 -84
  715. mindspore/ops/function/sparse_unary_func.py +2303 -0
  716. mindspore/ops/function/spectral_func.py +146 -0
  717. mindspore/ops/function/vmap_func.py +114 -0
  718. mindspore/ops/functional.py +182 -254
  719. mindspore/ops/op_info_register.py +79 -34
  720. mindspore/ops/operations/__init__.py +210 -118
  721. mindspore/ops/operations/_csr_ops.py +7 -7
  722. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  723. mindspore/ops/operations/_grad_ops.py +447 -322
  724. mindspore/ops/operations/_inner_ops.py +547 -176
  725. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  726. mindspore/ops/operations/_ms_kernel.py +29 -27
  727. mindspore/ops/operations/_ocr_ops.py +11 -11
  728. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  729. mindspore/ops/operations/_quant_ops.py +186 -101
  730. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  731. mindspore/ops/operations/_scalar_ops.py +466 -0
  732. mindspore/ops/operations/_sequence_ops.py +1047 -0
  733. mindspore/ops/operations/_tensor_array.py +10 -11
  734. mindspore/ops/operations/_thor_ops.py +4 -4
  735. mindspore/ops/operations/array_ops.py +1428 -1226
  736. mindspore/ops/operations/comm_ops.py +180 -117
  737. mindspore/ops/operations/control_ops.py +4 -2
  738. mindspore/ops/operations/custom_ops.py +185 -98
  739. mindspore/ops/operations/debug_ops.py +92 -54
  740. mindspore/ops/operations/image_ops.py +406 -211
  741. mindspore/ops/operations/inner_ops.py +42 -53
  742. mindspore/ops/operations/linalg_ops.py +32 -29
  743. mindspore/ops/operations/math_ops.py +2076 -897
  744. mindspore/ops/operations/nn_ops.py +1282 -1252
  745. mindspore/ops/operations/other_ops.py +124 -278
  746. mindspore/ops/operations/random_ops.py +345 -178
  747. mindspore/ops/operations/rl_ops.py +8 -9
  748. mindspore/ops/operations/sparse_ops.py +502 -157
  749. mindspore/ops/operations/spectral_ops.py +107 -0
  750. mindspore/ops/primitive.py +192 -15
  751. mindspore/ops/vm_impl_registry.py +23 -2
  752. mindspore/parallel/__init__.py +6 -1
  753. mindspore/parallel/_auto_parallel_context.py +199 -92
  754. mindspore/parallel/_cell_wrapper.py +4 -2
  755. mindspore/parallel/_cost_model_context.py +3 -0
  756. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  757. mindspore/parallel/_offload_context.py +185 -0
  758. mindspore/parallel/_parallel_serialization.py +167 -28
  759. mindspore/parallel/_ps_context.py +9 -5
  760. mindspore/parallel/_recovery_context.py +1 -1
  761. mindspore/parallel/_tensor.py +9 -1
  762. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  763. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  764. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  765. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  766. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  767. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  768. mindspore/parallel/_utils.py +47 -7
  769. mindspore/parallel/algo_parameter_config.py +5 -1
  770. mindspore/parallel/checkpoint_transform.py +329 -0
  771. mindspore/parallel/shard.py +229 -0
  772. mindspore/perf_msvcbuildinsights.dll +0 -0
  773. mindspore/pgodb140.dll +0 -0
  774. mindspore/pgort140.dll +0 -0
  775. mindspore/profiler/__init__.py +2 -1
  776. mindspore/profiler/common/util.py +4 -3
  777. mindspore/profiler/common/validator/validate_path.py +2 -2
  778. mindspore/profiler/envprofiling.py +249 -0
  779. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  780. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  781. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  782. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  783. mindspore/profiler/parser/framework_parser.py +42 -16
  784. mindspore/profiler/parser/hccl_parser.py +158 -158
  785. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  786. mindspore/profiler/parser/integrator.py +18 -1579
  787. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  788. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  789. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  790. mindspore/profiler/parser/optime_parser.py +17 -18
  791. mindspore/profiler/parser/profiler_info.py +108 -0
  792. mindspore/profiler/parser/step_trace_parser.py +1 -1
  793. mindspore/profiler/profiling.py +396 -194
  794. mindspore/rewrite/__init__.py +6 -2
  795. mindspore/rewrite/api/node.py +51 -110
  796. mindspore/rewrite/api/node_type.py +10 -6
  797. mindspore/rewrite/api/pattern_engine.py +51 -7
  798. mindspore/rewrite/api/scoped_value.py +64 -53
  799. mindspore/rewrite/api/symbol_tree.py +108 -61
  800. mindspore/rewrite/api/tree_node_helper.py +2 -3
  801. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  802. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  803. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  804. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  805. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  806. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  807. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  808. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  809. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  810. mindspore/rewrite/common/__init__.py +2 -0
  811. mindspore/rewrite/common/event.py +1 -1
  812. mindspore/rewrite/common/observable.py +1 -1
  813. mindspore/rewrite/common/observer.py +1 -1
  814. mindspore/rewrite/common/rewrite_elog.py +35 -0
  815. mindspore/rewrite/namer.py +2 -2
  816. mindspore/rewrite/namespace.py +14 -4
  817. mindspore/rewrite/node.py +161 -13
  818. mindspore/rewrite/parser.py +0 -1
  819. mindspore/rewrite/parser_register.py +0 -1
  820. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  821. mindspore/rewrite/parsers/assign_parser.py +267 -67
  822. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  823. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  824. mindspore/rewrite/parsers/constant_parser.py +101 -0
  825. mindspore/rewrite/parsers/container_parser.py +88 -0
  826. mindspore/rewrite/parsers/for_parser.py +28 -15
  827. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  828. mindspore/rewrite/parsers/if_parser.py +11 -28
  829. mindspore/rewrite/parsers/module_parser.py +9 -6
  830. mindspore/rewrite/parsers/return_parser.py +3 -2
  831. mindspore/rewrite/sparsify/__init__.py +0 -0
  832. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  833. mindspore/rewrite/sparsify/sparsify.py +109 -0
  834. mindspore/rewrite/sparsify/utils.py +173 -0
  835. mindspore/rewrite/symbol_tree.py +322 -109
  836. mindspore/rewrite/symbol_tree_builder.py +45 -8
  837. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  838. mindspore/rewrite/topological_manager.py +1 -2
  839. mindspore/run_check/_check_version.py +209 -112
  840. mindspore/run_check/run_check.py +2 -1
  841. mindspore/tbbmalloc.dll +0 -0
  842. mindspore/tinyxml2.dll +0 -0
  843. mindspore/train/__init__.py +6 -4
  844. mindspore/train/_utils.py +28 -5
  845. mindspore/train/amp.py +321 -50
  846. mindspore/train/callback/__init__.py +3 -1
  847. mindspore/train/callback/_backup_and_restore.py +120 -0
  848. mindspore/train/callback/_callback.py +8 -8
  849. mindspore/train/callback/_checkpoint.py +12 -9
  850. mindspore/train/callback/_early_stop.py +13 -7
  851. mindspore/train/callback/_history.py +8 -8
  852. mindspore/train/callback/_lambda_callback.py +6 -6
  853. mindspore/train/callback/_landscape.py +36 -38
  854. mindspore/train/callback/_loss_monitor.py +12 -6
  855. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  856. mindspore/train/callback/_on_request_exit.py +212 -0
  857. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  858. mindspore/train/callback/_summary_collector.py +27 -19
  859. mindspore/train/callback/_time_monitor.py +13 -7
  860. mindspore/train/checkpoint_pb2.py +68 -8
  861. mindspore/train/data_sink.py +122 -33
  862. mindspore/train/dataset_helper.py +28 -87
  863. mindspore/train/loss_scale_manager.py +4 -7
  864. mindspore/{nn → train}/metrics/__init__.py +20 -20
  865. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  866. mindspore/{nn → train}/metrics/auc.py +4 -4
  867. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  868. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  869. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  870. mindspore/{nn → train}/metrics/dice.py +6 -5
  871. mindspore/{nn → train}/metrics/error.py +7 -5
  872. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  873. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  874. mindspore/{nn → train}/metrics/loss.py +4 -3
  875. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  876. mindspore/{nn → train}/metrics/metric.py +6 -5
  877. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  878. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  879. mindspore/{nn → train}/metrics/precision.py +5 -4
  880. mindspore/{nn → train}/metrics/recall.py +5 -4
  881. mindspore/{nn → train}/metrics/roc.py +7 -6
  882. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  883. mindspore/{nn → train}/metrics/topk.py +7 -5
  884. mindspore/train/mind_ir_pb2.py +339 -32
  885. mindspore/train/model.py +113 -84
  886. mindspore/train/serialization.py +547 -167
  887. mindspore/train/summary/_summary_adapter.py +1 -1
  888. mindspore/train/summary/summary_record.py +43 -12
  889. mindspore/train/train_thor/convert_utils.py +7 -1
  890. mindspore/train/train_thor/dataset_helper.py +3 -3
  891. mindspore/train/train_thor/model_thor.py +0 -4
  892. mindspore/turbojpeg.dll +0 -0
  893. mindspore/vcmeta.dll +0 -0
  894. mindspore/vcruntime140.dll +0 -0
  895. mindspore/vcruntime140_1.dll +0 -0
  896. mindspore/version.py +1 -1
  897. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  898. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
  899. mindspore/compression/common/constant.py +0 -124
  900. mindspore/compression/export/__init__.py +0 -19
  901. mindspore/compression/export/quant_export.py +0 -514
  902. mindspore/compression/quant/qat.py +0 -636
  903. mindspore/compression/quant/quant_utils.py +0 -462
  904. mindspore/compression/quant/quantizer.py +0 -68
  905. mindspore/libatomic-1.dll +0 -0
  906. mindspore/libgcc_s_seh-1.dll +0 -0
  907. mindspore/libgfortran-4.dll +0 -0
  908. mindspore/libgomp-1.dll +0 -0
  909. mindspore/libjpeg-62.dll +0 -0
  910. mindspore/libmindspore.dll +0 -0
  911. mindspore/libmindspore_common.dll +0 -0
  912. mindspore/libmindspore_core.dll +0 -0
  913. mindspore/libmindspore_glog.dll +0 -0
  914. mindspore/libnnacl.dll +0 -0
  915. mindspore/libopencv_core452.dll +0 -0
  916. mindspore/libopencv_imgcodecs452.dll +0 -0
  917. mindspore/libopencv_imgproc452.dll +0 -0
  918. mindspore/libquadmath-0.dll +0 -0
  919. mindspore/libsqlite3.dll +0 -0
  920. mindspore/libssp-0.dll +0 -0
  921. mindspore/libstdc++-6.dll +0 -0
  922. mindspore/libtinyxml2.dll +0 -0
  923. mindspore/libturbojpeg.dll +0 -0
  924. mindspore/libwinpthread-1.dll +0 -0
  925. mindspore/nn/layer/quant.py +0 -1868
  926. mindspore/nn/layer/rnn_utils.py +0 -90
  927. mindspore/nn/probability/dpn/__init__.py +0 -22
  928. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  929. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  930. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  931. mindspore/nn/probability/infer/__init__.py +0 -22
  932. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  933. mindspore/nn/probability/infer/variational/svi.py +0 -84
  934. mindspore/nn/probability/toolbox/__init__.py +0 -22
  935. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  936. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  937. mindspore/nn/probability/transforms/__init__.py +0 -22
  938. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  939. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  940. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  941. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  942. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  943. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  944. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  945. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  946. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  947. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  948. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  949. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  950. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  951. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  952. mindspore/ops/composite/array_ops.py +0 -210
  953. mindspore/ops/composite/clip_ops.py +0 -238
  954. mindspore/ops/composite/random_ops.py +0 -426
  955. mindspore/ops/composite/vmap_ops.py +0 -38
  956. mindspore/ops/operations/sponge_ops.py +0 -3531
  957. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  958. mindspore/parallel/nn/__init__.py +0 -42
  959. mindspore/parallel/nn/loss.py +0 -22
  960. mindspore/parallel/nn/moe.py +0 -21
  961. mindspore/parallel/nn/op_parallel_config.py +0 -22
  962. mindspore/parallel/nn/transformer.py +0 -31
  963. mindspore/run_check/_check_deps_version.py +0 -84
  964. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  965. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  966. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -1,2546 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """
17
- Note:
18
- SPONGE operators for new modules. This is an experimental interface that is subject to change and/or deletion.
19
- """
20
- import math
21
- from ..primitive import PrimitiveWithInfer, prim_attr_register
22
- from ..._checkparam import Rel
23
- from ..._checkparam import Validator as validator
24
- from ...common import dtype as mstype
25
-
26
-
27
- class RefreshUintCrd(PrimitiveWithInfer):
28
- """
29
- Refresh the unsigned coordinate of each constrained atom in each constrain iteration.
30
-
31
- .. warning::
32
- This is an experimental prototype that is subject to change and/or deletion.
33
-
34
- Args:
35
- atom_numbers (int32): the number of atoms n.
36
- half_exp_gamma_plus_half (float32): constant value (1.0 + exp(gamma * dt)) if Langvin-Liu thermostat is used,
37
- where gamma is friction coefficient and dt is the simulation time step, 1.0 otherwise.
38
-
39
- Inputs:
40
- - **crd** (Tensor) - The coordinate of each atom.
41
- The data type is float32 and the shape is :math:`(n, 3)`.
42
- - **quarter_cof** (Tensor) - The 3-D scale factor.
43
- The data type is float32 and the shape is :math:`(3,)`.
44
- - **test_frc** (Tensor) - The constraint force.
45
- The data type is float32 and the shape is :math:`(n, 3)`.
46
- - **mass_inverse** (Tensor) - The inverse value of mass of each atom.
47
- The data type is float32 and the shape is :math:`(n,)`.
48
-
49
- Outputs:
50
- - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
51
- The data type is uint32 and the shape is :math:`(n, 3)`.
52
-
53
- Supported Platforms:
54
- ``GPU``
55
- """
56
-
57
- @prim_attr_register
58
- def __init__(self, atom_numbers, half_exp_gamma_plus_half):
59
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
60
- validator.check_value_type('half_exp_gamma_plus_half', half_exp_gamma_plus_half, float, self.name)
61
- self.atom_numbers = atom_numbers
62
- self.half_exp_gamma_plus_half = half_exp_gamma_plus_half
63
- self.add_prim_attr('atom_numbers', self.atom_numbers)
64
- self.add_prim_attr('half_exp_gamma_plus_half', self.half_exp_gamma_plus_half)
65
- self.init_prim_io_names(
66
- inputs=['crd', 'quarter_cof', 'test_frc', 'mass_inverse'],
67
- outputs=['uint_crd'])
68
-
69
- def infer_shape(self, crd_shape, quarter_cof_shape, test_frc_shape, mass_inverse_shape):
70
- cls_name = self.name
71
- n = self.atom_numbers
72
- validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", cls_name)
73
- validator.check_int(len(quarter_cof_shape), 1, Rel.EQ, "quarter_cof_dim", cls_name)
74
- validator.check_int(len(test_frc_shape), 2, Rel.EQ, "test_frc_dim", cls_name)
75
- validator.check_int(len(mass_inverse_shape), 1, Rel.EQ, "mass_inverse_dim", cls_name)
76
-
77
- validator.check_int(crd_shape[0], n, Rel.EQ, "crd_shape[0]", cls_name)
78
- validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
79
- validator.check_int(quarter_cof_shape[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
80
- validator.check_int(test_frc_shape[0], n, Rel.EQ, "test_frc_shape[0]", cls_name)
81
- validator.check_int(test_frc_shape[1], 3, Rel.EQ, "test_frc_shape[1]", cls_name)
82
- validator.check_int(mass_inverse_shape[0], n, Rel.EQ, "mass_inverse_shape", cls_name)
83
-
84
- return [n, 3]
85
-
86
- def infer_dtype(self, crd_dtype, quarter_cof_dtype, test_frc_dtype, mass_inverse_dtype):
87
- validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name)
88
- validator.check_tensor_dtype_valid('quarter_cof', quarter_cof_dtype, [mstype.float32], self.name)
89
- validator.check_tensor_dtype_valid('test_frc', test_frc_dtype, [mstype.float32], self.name)
90
- validator.check_tensor_dtype_valid('mass_inverse', mass_inverse_dtype, [mstype.float32], self.name)
91
- return mstype.uint32
92
-
93
-
94
- class ConstrainForceCycleWithVirial(PrimitiveWithInfer):
95
- """
96
- Calculate the constraint force and virial in each iteration.
97
-
98
- .. warning::
99
- This is an experimental prototype that is subject to change and/or deletion.
100
-
101
- Args:
102
- atom_numbers (int32): the number of atoms n.
103
- constrain_pair_numbers (int32): the number of constrain pairs m.
104
-
105
- Inputs:
106
- - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
107
- The data type is uint32 and the shape is :math:`(n, 3)`.
108
- - **scaler** (Tensor) - The 3-D scale factor (x, y, z),
109
- The data type is float32 and the shape is :math:`(3,)`.
110
- - **pair_dr** (Tensor) - The displacement vector of each constrained atom pair.
111
- The data type is float32 and the shape is :math:`(m, 3)`.
112
- - **atom_i_serials** (Tensor) - The first atom index of each constrained atom pair.
113
- The data type is int32 and the shape is :math:`(m,)`.
114
- - **atom_j_serials** (Tensor) - The second atom index of each constrained atom pair.
115
- The data type is int32 and the shape is :math:`(m,)`.
116
- - **constant_rs** (Tensor) - The constrained distance of each constrained atom pair.
117
- The data type is float32 and the shape is :math:`(m,)`.
118
- - **constrain_ks** (Tensor) - The coefficient of each constrained atom pair.
119
- The data type is float32 and the shape is :math:`(m,)`.
120
-
121
- Outputs:
122
- - **test_frc** (Tensor) - The constraint force.
123
- The data type is float32 and the shape is :math:`(n, 3)`.
124
- - **atom_virial** (Tensor) - The virial caused by constraint force of each atom.
125
- The data type is float32 and the shape is :math:`(m,)`.
126
-
127
- Supported Platforms:
128
- ``GPU``
129
- """
130
-
131
- @prim_attr_register
132
- def __init__(self, atom_numbers, constrain_pair_numbers):
133
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
134
- validator.check_value_type('constrain_pair_numbers', constrain_pair_numbers, int, self.name)
135
- self.atom_numbers = atom_numbers
136
- self.constrain_pair_numbers = constrain_pair_numbers
137
- self.add_prim_attr('atom_numbers', self.atom_numbers)
138
- self.add_prim_attr('constrain_pair_numbers', self.constrain_pair_numbers)
139
- self.init_prim_io_names(
140
- inputs=['uint_crd', 'scaler', 'pair_dr', 'atom_i_serials', 'atom_j_serials',
141
- 'constant_rs', 'constrain_ks'],
142
- outputs=['test_frc', 'atom_virial'])
143
-
144
- def infer_shape(self, uint_crd_shape, scaler_shape, pair_dr_shape, atom_i_serials_shape,
145
- atom_j_serials_shape, constant_rs_shape, constrain_ks_shape):
146
- cls_name = self.name
147
- n = self.atom_numbers
148
- m = self.constrain_pair_numbers
149
- validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name)
150
- validator.check_int(len(scaler_shape), 1, Rel.EQ, "scaler_dim", cls_name)
151
- validator.check_int(len(pair_dr_shape), 2, Rel.EQ, "pair_dr_dim", cls_name)
152
- validator.check_int(len(atom_i_serials_shape), 1, Rel.EQ, "atom_i_serials_dim", cls_name)
153
- validator.check_int(len(atom_j_serials_shape), 1, Rel.EQ, "atom_j_serials_dim", cls_name)
154
- validator.check_int(len(constant_rs_shape), 1, Rel.EQ, "constant_rs_dim", cls_name)
155
- validator.check_int(len(constrain_ks_shape), 1, Rel.EQ, "constrain_ks_dim", cls_name)
156
-
157
- validator.check_int(uint_crd_shape[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
158
- validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
159
- validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
160
- validator.check_int(pair_dr_shape[0], m, Rel.EQ, "pair_dr_shape[0]", cls_name)
161
- validator.check_int(pair_dr_shape[1], 3, Rel.EQ, "pair_dr_shape[1]", cls_name)
162
- validator.check_int(atom_i_serials_shape[0], m, Rel.EQ, "atom_i_serials_shape", cls_name)
163
- validator.check_int(atom_j_serials_shape[0], m, Rel.EQ, "atom_j_serials_shape", cls_name)
164
- validator.check_int(constant_rs_shape[0], m, Rel.EQ, "constant_rs_shape", cls_name)
165
- validator.check_int(constrain_ks_shape[0], m, Rel.EQ, "constrain_ks_shape", cls_name)
166
- return [n, 3], [m,]
167
-
168
- def infer_dtype(self, uint_crd_dtype, scaler_dtype, pair_dr_dtype, atom_i_serials_dtype,
169
- atom_j_serials_dtype, constant_rs_dtype, constrain_ks_dtype):
170
- validator.check_tensor_dtype_valid('uint_crd', uint_crd_dtype, [mstype.uint32], self.name)
171
- validator.check_tensor_dtype_valid('scaler', scaler_dtype, [mstype.float32], self.name)
172
- validator.check_tensor_dtype_valid('pair_dr', pair_dr_dtype, [mstype.float32], self.name)
173
- validator.check_tensor_dtype_valid('atom_i_serials', atom_i_serials_dtype, [mstype.int32], self.name)
174
- validator.check_tensor_dtype_valid('atom_j_serials', atom_j_serials_dtype, [mstype.int32], self.name)
175
- validator.check_tensor_dtype_valid('constant_rs', constant_rs_dtype, [mstype.float32], self.name)
176
- validator.check_tensor_dtype_valid('constrain_ks', constrain_ks_dtype, [mstype.float32], self.name)
177
- return mstype.float32, mstype.float32
178
-
179
-
180
class LastCrdToDr(PrimitiveWithInfer):
    """
    Calculate the displacement vector of each constrained atom pair.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        atom_numbers (int32): the number of atoms n.
        constrain_pair_numbers (int32): the number of constrain pairs m.

    Inputs:
        - **crd** (Tensor) - The coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **quarter_cof** (Tensor) - The 3-D scale factor.
          The data type is float32 and the shape is :math:`(3,)`.
        - **uint_dr_to_dr** (Tensor) - The 3-D scale factor (x, y, z).
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_i_serials** (Tensor) - The first atom index of each constrained atom pair.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_j_serials** (Tensor) - The second atom index of each constrained atom pair.
          The data type is int32 and the shape is :math:`(m,)`.
        - **constant_rs** (Tensor) - The constrained distance of each constrained atom pair.
          The data type is float32 and the shape is :math:`(m,)`.
        - **constrain_ks** (Tensor) - The coefficient of each constrained atom pair.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **pair_dr** (Tensor) - The displacement vector of each constrained atom pair.
          The data type is float32 and the shape is :math:`(m, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, constrain_pair_numbers):
        """Initialize LastCrdToDr."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('constrain_pair_numbers', constrain_pair_numbers, int, self.name)
        self.constrain_pair_numbers = constrain_pair_numbers
        self.atom_numbers = atom_numbers
        self.add_prim_attr('constrain_pair_numbers', self.constrain_pair_numbers)
        self.init_prim_io_names(
            inputs=['crd', 'quarter_cof', 'uint_dr_to_dr', 'atom_i_serials', 'atom_j_serials',
                    'constant_rs', 'constrain_ks'],
            outputs=['pair_dr'])

    def infer_shape(self, crd_shape, quarter_cof_shape, uint_dr_to_dr_shape, atom_i_serials_shape,
                    atom_j_serials_shape, constant_rs_shape, constrain_ks_shape):
        """Validate input ranks and sizes; return the output shape ``[m, 3]``."""
        cls_name = self.name
        n = self.atom_numbers
        m = self.constrain_pair_numbers
        # Rank checks come first so the per-axis indexing below cannot raise IndexError.
        validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", cls_name)
        validator.check_int(len(quarter_cof_shape), 1, Rel.EQ, "quarter_cof_dim", cls_name)
        # Label previously read "quarter_cof_dim" (copy-paste), misreporting which input failed.
        validator.check_int(len(uint_dr_to_dr_shape), 1, Rel.EQ, "uint_dr_to_dr_dim", cls_name)
        validator.check_int(len(atom_i_serials_shape), 1, Rel.EQ, "atom_i_serials_dim", cls_name)
        validator.check_int(len(atom_j_serials_shape), 1, Rel.EQ, "atom_j_serials_dim", cls_name)
        validator.check_int(len(constant_rs_shape), 1, Rel.EQ, "constant_rs_dim", cls_name)
        validator.check_int(len(constrain_ks_shape), 1, Rel.EQ, "constrain_ks_dim", cls_name)

        validator.check_int(crd_shape[0], n, Rel.EQ, "crd_shape[0]", cls_name)
        validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
        validator.check_int(quarter_cof_shape[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
        validator.check_int(uint_dr_to_dr_shape[0], 3, Rel.EQ, "uint_dr_to_dr_shape", cls_name)
        validator.check_int(atom_i_serials_shape[0], m, Rel.EQ, "atom_i_serials_shape", cls_name)
        validator.check_int(atom_j_serials_shape[0], m, Rel.EQ, "atom_j_serials_shape", cls_name)
        validator.check_int(constant_rs_shape[0], m, Rel.EQ, "constant_rs_shape", cls_name)
        validator.check_int(constrain_ks_shape[0], m, Rel.EQ, "constrain_ks_shape", cls_name)

        return [m, 3]

    def infer_dtype(self, crd_dtype, quarter_cof_dtype, uint_dr_to_dr_dtype, atom_i_serials_dtype,
                    atom_j_serials_dtype, constant_rs_dtype, constrain_ks_dtype):
        """Validate input dtypes; the output ``pair_dr`` is float32."""
        validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('quarter_cof', quarter_cof_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('uint_dr_to_dr', uint_dr_to_dr_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('atom_i_serials', atom_i_serials_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_j_serials', atom_j_serials_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('constant_rs', constant_rs_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('constrain_ks', constrain_ks_dtype, [mstype.float32], self.name)
        return mstype.float32
261
-
262
-
263
class RefreshCrdVel(PrimitiveWithInfer):
    """
    Update the coordinate and velocity of every constrained atom once all iterations are finished.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        atom_numbers (int32): the number of atoms n.
        dt_inverse (float32): the inverse value of simulation time step.
        dt (float32): the simulation time step.
        exp_gamma (float32): constant value exp(gamma * dt).
        half_exp_gamma_plus_half (float32): constant value (1 + exp_gamma)/2.

    Inputs:
        - **crd** (Tensor) - The coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **vel** (Tensor) - The velocity of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **test_frc** (Tensor) - The constraint force calculated in the last iteration.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **mass_inverse** (Tensor) - The inverse value of mass of each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Outputs:
        - **res** (Tensor) - The return value after updating successfully.
          The data type is float32 and the shape is :math:`(1,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, dt_inverse, dt, exp_gamma, half_exp_gamma_plus_half):
        """Type-check the scalar arguments and register them as primitive attributes."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        float_args = (('dt', dt),
                      ('dt_inverse', dt_inverse),
                      ('exp_gamma', exp_gamma),
                      ('half_exp_gamma_plus_half', half_exp_gamma_plus_half))
        for arg_name, arg_value in float_args:
            validator.check_value_type(arg_name, arg_value, float, self.name)

        self.atom_numbers = atom_numbers
        self.dt_inverse = dt_inverse
        self.dt = dt
        self.exp_gamma = exp_gamma
        self.half_exp_gamma_plus_half = half_exp_gamma_plus_half
        for attr_name in ('atom_numbers', 'dt_inverse', 'dt', 'exp_gamma', 'half_exp_gamma_plus_half'):
            self.add_prim_attr(attr_name, getattr(self, attr_name))

        self.init_prim_io_names(
            inputs=['crd', 'vel', 'test_frc', 'mass_inverse'],
            outputs=['res'])
        # Marked as having memory side effects; presumably crd/vel are updated
        # in place by the kernel -- confirm against the GPU implementation.
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, crd_shape, vel_shape, test_frc_shape, mass_inverse_shape):
        """Check input ranks and sizes against n; the output shape is ``[1]``."""
        prim_name = self.name
        atom_count = self.atom_numbers

        for shape, rank, label in ((crd_shape, 2, "crd_dim"),
                                   (vel_shape, 2, "vel_dim"),
                                   (test_frc_shape, 2, "test_frc_dim"),
                                   (mass_inverse_shape, 1, "mass_inverse_dim")):
            validator.check_int(len(shape), rank, Rel.EQ, label, prim_name)

        size_checks = (
            (crd_shape[0], atom_count, "crd_shape[0]"),
            (crd_shape[1], 3, "crd_shape[1]"),
            (vel_shape[0], atom_count, "vel_shape[0]"),
            (vel_shape[1], 3, "vel_shape[1]"),
            (test_frc_shape[0], atom_count, "test_frc_shape[0]"),
            (test_frc_shape[1], 3, "test_frc_shape[1]"),
            (mass_inverse_shape[0], atom_count, "mass_inverse_shape[0]"),
        )
        for actual, expected, label in size_checks:
            validator.check_int(actual, expected, Rel.EQ, label, prim_name)
        return [1,]

    def infer_dtype(self, crd_dtype, vel_dtype, test_frc_dtype, mass_inverse_dtype):
        """All inputs must be float32; the output is float32."""
        for arg_name, arg_dtype in (('crd', crd_dtype),
                                    ('vel', vel_dtype),
                                    ('test_frc', test_frc_dtype),
                                    ('mass_inverse', mass_inverse_dtype)):
            validator.check_tensor_dtype_valid(arg_name, arg_dtype, [mstype.float32], self.name)
        return mstype.float32
340
-
341
-
342
class CalculateNowrapCrd(PrimitiveWithInfer):
    """
    Calculate the inside-box periodic image of each atom.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **crd** (Tensor) - The coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **box** (Tensor) - The 3-D size of system.
          The data type is float32 and the shape is :math:`(3, )`.
        - **box_map_times** (Tensor) - The number of times each atom has crossed the box.
          The data type is int32 and the shape is :math:`(n, 3)`.

    Outputs:
        - **nowrap_crd** (Tensor) - The inside-box periodic image of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers):
        """Initialize CalculateNowrapCrd."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.atom_numbers = atom_numbers
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(
            inputs=['crd', 'box', 'box_map_times'],
            outputs=['nowrap_crd'])

    def infer_shape(self, crd_shape, box_shape, box_map_times_shape):
        """Validate input ranks and sizes; the output has shape ``[n, 3]``."""
        cls_name = self.name
        n = self.atom_numbers
        validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", cls_name)
        validator.check_int(len(box_shape), 1, Rel.EQ, "box_dim", cls_name)
        validator.check_int(len(box_map_times_shape), 2, Rel.EQ, "box_map_times_dim", cls_name)

        validator.check_int(crd_shape[0], n, Rel.EQ, "crd_shape[0]", cls_name)
        validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
        validator.check_int(box_shape[0], 3, Rel.EQ, "box_shape[0]", cls_name)
        validator.check_int(box_map_times_shape[0], n, Rel.EQ, "box_map_times_shape[0]", cls_name)
        validator.check_int(box_map_times_shape[1], 3, Rel.EQ, "box_map_times_shape[1]", cls_name)
        return [n, 3]

    def infer_dtype(self, crd_dtype, box_dtype, box_map_times_dtype):
        """Validate input dtypes; the output ``nowrap_crd`` is float32."""
        validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('box', box_dtype, [mstype.float32], self.name)
        # box_map_times is an integer crossing counter: int32 per this docstring and
        # per RefreshBoxmapTimes, which validates the same tensor as int32.
        validator.check_tensor_dtype_valid('box_map_times', box_map_times_dtype,
                                           [mstype.int32], self.name)
        return mstype.float32
397
-
398
-
399
class RefreshBoxmapTimes(PrimitiveWithInfer):
    """
    Update how many times each atom has crossed the periodic box.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **crd** (Tensor) - The coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **old_crd** (Tensor) - The coordinate of each atom at last update.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **box_length_inverse** (Tensor) - The inverse value of box length in 3 dimensions.
          The data type is float32 and the shape is :math:`(3,)`.
        - **box_map_times** (Tensor) - The number of times each atom has crossed the box.
          The data type is int32 and the shape is :math:`(n, 3)`.

    Outputs:
        - **res** (Tensor) - The return value after updating successfully.
          The data type is float32 and the shape is :math:`(1,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers):
        """Type-check ``atom_numbers`` and register it as a primitive attribute."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.atom_numbers = atom_numbers
        self.add_prim_attr('atom_numbers', atom_numbers)
        self.init_prim_io_names(
            inputs=['crd', 'old_crd', 'box_length_inverse', 'box_map_times'],
            outputs=['res'])

    def infer_shape(self, crd_shape, old_crd_shape, box_length_inverse_shape, box_map_times_shape):
        """Check every input's rank and size; the output shape is ``[1]``."""
        prim_name = self.name
        atom_count = self.atom_numbers

        rank_checks = (
            (crd_shape, 2, "crd_dim"),
            (old_crd_shape, 2, "old_crd_dim"),
            (box_length_inverse_shape, 1, "box_length_inverse_dim"),
            (box_map_times_shape, 2, "box_map_times_dim"),
        )
        for shape, rank, label in rank_checks:
            validator.check_int(len(shape), rank, Rel.EQ, label, prim_name)

        size_checks = (
            (crd_shape[0], atom_count, "crd_shape[0]"),
            (crd_shape[1], 3, "crd_shape[1]"),
            (old_crd_shape[0], atom_count, "old_crd_shape[0]"),
            (old_crd_shape[1], 3, "old_crd_shape[1]"),
            (box_length_inverse_shape[0], 3, "box_length_inverse_shape[0]"),
            (box_map_times_shape[0], atom_count, "box_map_times_shape[0]"),
            (box_map_times_shape[1], 3, "box_map_times_shape[1]"),
        )
        for actual, expected, label in size_checks:
            validator.check_int(actual, expected, Rel.EQ, label, prim_name)
        return [1,]

    def infer_dtype(self, crd_dtype, old_crd_dtype, box_length_inverse_dtype, box_map_times_dtype):
        """Coordinates and inverse box length are float32, crossing counts int32; output float32."""
        expected_dtypes = (
            ('crd', crd_dtype, mstype.float32),
            ('old_crd', old_crd_dtype, mstype.float32),
            ('box_length_inverse', box_length_inverse_dtype, mstype.float32),
            ('box_map_times', box_map_times_dtype, mstype.int32),
        )
        for arg_name, arg_dtype, valid_dtype in expected_dtypes:
            validator.check_tensor_dtype_valid(arg_name, arg_dtype, [valid_dtype], self.name)
        return mstype.float32
460
-
461
-
462
class Totalc6get(PrimitiveWithInfer):
    """
    Get the average dispersion constant of short range Lennard-Jones interaction,
    for the subsequent long range correction energy and virial. Assume system has m Lennard-Jones types of atoms.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **atom_lj_type** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **lj_b** (Tensor) - The attraction coefficient of each type. The number of pair atoms is m.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **factor** (Tensor) - The average dispersion constant of Lennard-Jones interaction.
          The data type is float32 and the shape is :math:`(1,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers):
        """Initialize Totalc6get."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.atom_numbers = atom_numbers
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(
            inputs=['atom_lj_type', 'lj_b'],
            outputs=['factor'])

    def infer_shape(self, atom_lj_type, lj_b):
        """Validate input shapes; the output ``factor`` has shape ``[1]``.

        Only the length of ``atom_lj_type`` is checked against n; the length of
        ``lj_b`` is not validated here.
        """
        cls_name = self.name
        n = self.atom_numbers
        validator.check_int(len(atom_lj_type), 1, Rel.EQ, "atom_lj_type_dim", cls_name)
        validator.check_int(len(lj_b), 1, Rel.EQ, "LJ_b_dim", cls_name)
        validator.check_int(atom_lj_type[0], n, Rel.EQ, "atom_lj_type_shape[0]", cls_name)
        return [1,]

    def infer_dtype(self, atom_lj_type, lj_b):
        """Validate dtypes (int32 type ids, float32 coefficients); the output is float32."""
        # List form for consistency with every other dtype check in this module.
        validator.check_tensor_dtype_valid('atom_lj_type', atom_lj_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('lj_b', lj_b, [mstype.float32], self.name)
        return mstype.float32
508
-
509
-
510
class CrdToUintCrdQuarter(PrimitiveWithInfer):
    """
    Convert float32 coordinates into uint32 coordinates.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **crd_to_uint_crd_cof** (Tensor) - The crd_to_uint_crd coefficient.
          The data type is float32 and the shape is :math:`(3,)`.
        - **crd** (Tensor) - The coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Outputs:
        - **output** (Tensor) - The unsigned int coordinates.
          The data type is unsigned int32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers):
        """Initialize CrdToUintCrdQuarter"""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.atom_numbers = atom_numbers
        self.add_prim_attr('atom_numbers', atom_numbers)
        self.init_prim_io_names(
            inputs=['crd_to_uint_crd_cof', 'crd'],
            outputs=['output'])

    def infer_shape(self, crd_to_uint_crd_cof, crd):
        """Check both input shapes; the output shape equals the coordinate shape."""
        prim_name = self.name
        atom_count = self.atom_numbers
        validator.check_int(len(crd), 2, Rel.EQ, "crd_dim", prim_name)
        validator.check_int(len(crd_to_uint_crd_cof), 1, Rel.EQ, "crd_to_uint_crd_cof_dim", prim_name)
        size_checks = (
            (crd_to_uint_crd_cof[0], 3, "crd_to_uint_crd_cof_shape"),
            (crd[0], atom_count, "crd[0]"),
            (crd[1], 3, "crd[1]"),
        )
        for actual, expected, label in size_checks:
            validator.check_int(actual, expected, Rel.EQ, label, prim_name)
        return crd

    def infer_dtype(self, crd_to_uint_crd_cof, crd):
        """Both inputs must be float32; the output is uint32."""
        for arg_name, arg_dtype in (('crd_to_uint_crd_cof', crd_to_uint_crd_cof), ('crd', crd)):
            validator.check_tensor_dtype_valid(arg_name, arg_dtype, [mstype.float32], self.name)
        return mstype.uint32
558
-
559
-
560
class MDIterationLeapFrogLiujianWithMaxVel(PrimitiveWithInfer):
    """
    One step of classical leap frog algorithm to solve the finite difference
    Hamiltonian equations of motion for certain system, using Langevin dynamics
    with Liu's thermostat scheme, but with a maximum velocity limit. Assume the
    number of atoms is n and the target control temperature is T.

    Detailed iteration formula can be found in this paper: A unified thermostat
    scheme for efficient configurational sampling for classical/quantum canonical
    ensembles via molecular dynamics. DOI: 10.1063/1.4991621.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        atom_numbers (int32): the number of atoms n.
        dt (float32): time step for finite difference.
        half_dt (float32): half of time step for finite difference.
        exp_gamma (float32): parameter in Liu's dynamic, exp(-gamma_ln * dt).
        max_vel (float32): the maximum velocity limit.

    Inputs:
        - **inverse_mass** (Tensor) - The inverse value of mass of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **sqrt_mass_inverse** (Tensor) - The inverse sqrt of the mass in Liu's dynamics of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **vel** (Tensor) - The velocity of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **crd** (Tensor) - The coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **frc** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **acc** (Tensor) - The acceleration of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **rand_state** (Tensor) - Random state to generate random force.
          The data type is float32 and the shape is :math:`(math.ceil(n * 3.0 / 4.0) * 16, )`.
        - **rand_frc** (Tensor) - The random forces.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Outputs:
        - **output** (Tensor) - The output coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, half_dt, dt, exp_gamma, max_vel):
        """Initialize MDIterationLeapFrogLiujianWithMaxVel."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('half_dt', half_dt, float, self.name)
        validator.check_value_type('dt', dt, float, self.name)
        validator.check_value_type('exp_gamma', exp_gamma, float, self.name)
        validator.check_value_type('max_vel', max_vel, float, self.name)
        self.atom_numbers = atom_numbers
        self.half_dt = half_dt
        self.dt = dt
        self.exp_gamma = exp_gamma
        self.max_vel = max_vel

        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.add_prim_attr('half_dt', self.half_dt)
        self.add_prim_attr('dt', self.dt)
        self.add_prim_attr('exp_gamma', self.exp_gamma)
        self.add_prim_attr('max_vel', self.max_vel)
        self.init_prim_io_names(
            inputs=['inverse_mass', 'sqrt_mass_inverse', 'vel', 'crd', 'frc', 'acc', 'rand_state', 'rand_frc'],
            outputs=['output'])
        # Marked as having memory side effects; presumably some inputs are updated
        # in place by the kernel -- confirm against the GPU implementation.
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, inverse_mass, sqrt_mass_inverse, vel, crd, frc, acc, rand_state, rand_frc):
        """Validate input ranks and sizes; the output has shape ``[n, 3]``."""
        n = self.atom_numbers
        validator.check_int(len(inverse_mass), 1, Rel.EQ, "inverse_mass_dim", self.name)
        validator.check_int(len(sqrt_mass_inverse), 1, Rel.EQ, "sqrt_mass_inverse_dim", self.name)
        validator.check_int(len(rand_state), 1, Rel.EQ, "rand_state_dim", self.name)
        validator.check_int(len(rand_frc), 2, Rel.EQ, "rand_frc_dim", self.name)
        validator.check_int(len(vel), 2, Rel.EQ, "vel_dim", self.name)
        validator.check_int(len(crd), 2, Rel.EQ, "crd_dim", self.name)
        validator.check_int(len(frc), 2, Rel.EQ, "frc_dim", self.name)
        validator.check_int(len(acc), 2, Rel.EQ, "acc_dim", self.name)
        validator.check_int(inverse_mass[0], n, Rel.EQ, "inverse_mass", self.name)
        validator.check_int(sqrt_mass_inverse[0], n, Rel.EQ, "sqrt_mass_inverse", self.name)
        validator.check_int(vel[0], n, Rel.EQ, "vel_shape[0]", self.name)
        validator.check_int(vel[1], 3, Rel.EQ, "vel_shape[1]", self.name)
        validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", self.name)
        validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", self.name)
        validator.check_int(frc[0], n, Rel.EQ, "frc_shape[0]", self.name)
        validator.check_int(frc[1], 3, Rel.EQ, "frc_shape[1]", self.name)
        validator.check_int(acc[0], n, Rel.EQ, "acc_shape[0]", self.name)
        validator.check_int(acc[1], 3, Rel.EQ, "acc_shape[1]", self.name)
        validator.check_int(rand_frc[0], n, Rel.EQ, "rand_frc_shape[0]", self.name)
        validator.check_int(rand_frc[1], 3, Rel.EQ, "rand_frc_shape[1]", self.name)
        # Required rand_state length: ceil(3n / 4) * 16, as stated in the class docstring.
        validator.check_int(rand_state[0], math.ceil(n * 3 / 4.0) * 16, Rel.EQ, "rand_state", self.name)
        return [n, 3]

    def infer_dtype(self, inverse_mass, sqrt_mass_inverse, vel, crd, frc, acc, rand_state, rand_frc):
        """All inputs must be float32; the output is float32."""
        validator.check_tensor_dtype_valid('inverse_mass', inverse_mass, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('sqrt_mass_inverse', sqrt_mass_inverse, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('vel', vel, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('crd', crd, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('frc', frc, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('acc', acc, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('rand_frc', rand_frc, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('rand_state', rand_state, [mstype.float32], self.name)
        return mstype.float32
666
-
667
-
668
class GetCenterOfMass(PrimitiveWithInfer):
    """
    Get coordinate of centroid of each residue. Assume system has n atoms.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        residue_numbers (int32): the number of residues m.

    Inputs:
        - **start** (Tensor) - The start atom index of each residue.
          The data type is int32 and the shape is :math:`(m,)`.
        - **end** (Tensor) - The end atom index of each residue.
          The data type is int32 and the shape is :math:`(m,)`.
        - **crd** (Tensor) - The coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **atom_mass** (Tensor) - The mass of each atom and the atom number is n.
          The data type is float32 and the shape is :math:`(n,)`.
        - **residue_mass_inverse** (Tensor) - The inverse of mass of each residue.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **center_of_mass** (Tensor) - The coordinate of centroid of each residue.
          The data type is float32 and the shape is :math:`(m, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, residue_numbers):
        """Initialize GetCenterOfMass"""
        validator.check_value_type('residue_numbers', residue_numbers, int, self.name)
        self.residue_numbers = residue_numbers
        self.add_prim_attr('residue_numbers', self.residue_numbers)
        self.init_prim_io_names(
            inputs=['start', 'end', 'crd', 'atom_mass', 'residue_mass_inverse'],
            outputs=['center_of_mass'])

    def infer_shape(self, start, end, crd, atom_mass, residue_mass_inverse):
        """Validate input ranks and sizes; the output has shape ``[m, 3]``.

        The atom count n is not an argument: it is taken from ``crd``'s first axis.
        """
        m = self.residue_numbers
        validator.check_int(len(start), 1, Rel.EQ, "start_dim", self.name)
        validator.check_int(len(end), 1, Rel.EQ, "end_dim", self.name)
        validator.check_int(len(crd), 2, Rel.EQ, "crd_dim", self.name)
        validator.check_int(len(atom_mass), 1, Rel.EQ, "atom_mass_dim", self.name)
        validator.check_int(len(residue_mass_inverse), 1, Rel.EQ, "residue_mass_inverse_dim", self.name)
        # Read n only after crd's rank is validated; previously a rank-0 crd
        # raised a bare IndexError here instead of the validator's error.
        n = crd[0]
        validator.check_int(start[0], m, Rel.EQ, "start_shape", self.name)
        validator.check_int(end[0], m, Rel.EQ, "end_shape", self.name)
        validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", self.name)
        validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", self.name)
        validator.check_int(atom_mass[0], n, Rel.EQ, "atom_mass_shape[0]", self.name)
        validator.check_int(residue_mass_inverse[0], m, Rel.EQ, "residue_mass_inverse_shape", self.name)
        return [m, 3]

    def infer_dtype(self, start, end, crd, atom_mass, residue_mass_inverse):
        """Index inputs must be int32 and the rest float32; the output is float32."""
        validator.check_tensor_dtype_valid('start', start, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('end', end, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('crd', crd, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('atom_mass', atom_mass, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('residue_mass_inverse', residue_mass_inverse, [mstype.float32], self.name)
        return mstype.float32
731
-
732
-
733
- class MapCenterOfMass(PrimitiveWithInfer):
734
- """
735
- Map all atoms in the same residue to the same periodic box, scale if necessary (usually in pressurestat).
736
- Assume system has n atoms.
737
-
738
- .. warning::
739
- This is an experimental prototype that is subject to change and/or deletion.
740
-
741
- Args:
742
- residue_numbers (int32): the number of residues m.
743
-
744
- Inputs:
745
- - **start** (Tensor) - The start atom index of each residue.
746
- The data type is int32 and the shape is :math:`(m,)`.
747
- - **end** (Tensor) - The end atom index of each residue.
748
- The data type is int32 and the shape is :math:`(m,)`.
749
- - **center_of_mass** (Tensor) - The coordinate of centroid of each residue.
750
- The data type is float32 and the shape is :math:`(m, 3)`.
751
- - **box_length** (Tensor) - The box length of the simulation box.
752
- The data type is float32 and the shape is :math:`(3,)`.
753
- - **no_wrap_crd** (Tensor) - The coordinate of each atom before wrap.
754
- The data type is float32 and the shape is :math:`(n, 3)`.
755
- - **crd** (Tensor) - The coordinate of each atom after wrap.
756
- The data type is float32 and the shape is :math:`(n, 3)`.
757
- - **scaler** (Tensor) - The scaler of system.
758
- The data type is float32 and the shape is :math:`(1,)`.
759
-
760
- Outputs:
761
- - **res** (Tensor) - The return value after updating successfully.
762
- The data type is float32 and the shape is :math:`(1,)`.
763
-
764
- Supported Platforms:
765
- ``GPU``
766
- """
767
-
768
- @prim_attr_register
769
- def __init__(self, residue_numbers):
770
- """Initialize MapCenterOfMass"""
771
- validator.check_value_type('residue_numbers', residue_numbers, int, self.name)
772
- self.residue_numbers = residue_numbers
773
- self.add_prim_attr('residue_numbers', self.residue_numbers)
774
- self.init_prim_io_names(
775
- inputs=['start', 'end', 'center_of_mass', 'box_length',
776
- 'no_wrap_crd', 'crd', 'scaler'],
777
- outputs=['res'])
778
-
779
def infer_shape(self, start, end, center_of_mass, box_length, no_wrap_crd, crd, scaler):
    """Validate the input shapes of MapCenterOfMass and return the output shape.

    Expected shapes (m = self.residue_numbers, n = number of atoms taken from crd):
    start (m,), end (m,), center_of_mass (m, 3), box_length (3,),
    no_wrap_crd (n, 3), crd (n, 3), scaler (1,).

    Returns:
        list: the output shape ``[1]``.
    """
    m = self.residue_numbers
    validator.check_int(len(start), 1, Rel.EQ, "start_dim", self.name)
    validator.check_int(len(end), 1, Rel.EQ, "end_dim", self.name)
    validator.check_int(len(center_of_mass), 2, Rel.EQ, "center_of_mass_dim", self.name)
    validator.check_int(len(box_length), 1, Rel.EQ, "box_length_dim", self.name)
    validator.check_int(len(no_wrap_crd), 2, Rel.EQ, "no_wrap_crd_dim", self.name)
    validator.check_int(len(crd), 2, Rel.EQ, "crd_dim", self.name)
    validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", self.name)

    # Read the atom count only after crd's rank has been validated. Indexing first
    # would turn a malformed (e.g. rank-0) crd shape into a bare IndexError instead
    # of the validator's descriptive error.
    n = crd[0]

    validator.check_int(start[0], m, Rel.EQ, "start_shape", self.name)
    validator.check_int(end[0], m, Rel.EQ, "end_shape", self.name)
    validator.check_int(center_of_mass[0], m, Rel.EQ, "center_of_mass_shape[0]", self.name)
    validator.check_int(center_of_mass[1], 3, Rel.EQ, "center_of_mass_shape[1]", self.name)
    validator.check_int(box_length[0], 3, Rel.EQ, "box_length_shape", self.name)
    validator.check_int(scaler[0], 1, Rel.EQ, "scaler_shape", self.name)
    validator.check_int(no_wrap_crd[0], n, Rel.EQ, "no_wrap_crd_shape[0]", self.name)
    validator.check_int(no_wrap_crd[1], 3, Rel.EQ, "no_wrap_crd_shape[1]", self.name)
    validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", self.name)
    validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", self.name)
    return [1,]
801
-
802
def infer_dtype(self, start, end, center_of_mass, box_length, no_wrap_crd, crd, scaler):
    """Check the input dtypes of MapCenterOfMass and return the output dtype.

    The index tensors (start, end) must be int32; every other input must be float32.

    Returns:
        mstype.float32
    """
    # Checked in this exact order so the first reported mismatch is unchanged.
    dtype_expectations = (
        ('start', start, mstype.int32),
        ('end', end, mstype.int32),
        ('crd', crd, mstype.float32),
        ('center_of_mass', center_of_mass, mstype.float32),
        ('box_length', box_length, mstype.float32),
        ('no_wrap_crd', no_wrap_crd, mstype.float32),
        ('scaler', scaler, mstype.float32),
    )
    for arg_name, arg_dtype, wanted in dtype_expectations:
        validator.check_tensor_dtype_valid(arg_name, arg_dtype, [wanted], self.name)
    return mstype.float32
811
-
812
-
813
class NeighborListRefresh(PrimitiveWithInfer):
    """
    Update (or construct if first time) the Verlet neighbor list for the
    calculation of short-ranged force.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        grid_numbers (int32): the total number of grids divided G.
        atom_numbers (int32): the number of atoms n.
        not_first_time (int32): whether to construct the neighbor list first time or not.
        nxy (int32): the total number of grids divided in xy plane.
        excluded_atom_numbers (int32): the total atom numbers in the excluded list E.
        cutoff_square (float32): the cutoff square distance for short-range force calculation.
        half_skin_square (float32): the maximum square value of the distance atom allowed to move between two updates.
        cutoff_with_skin (float32): cutoff + skin, indicates the radius of the neighbor list for each atom.
        half_cutoff_with_skin (float32): cutoff_with_skin/2.
        cutoff_with_skin_square (float32): the square value of cutoff_with_skin.
        refresh_interval (int32): the number of iteration steps between two updates of neighbor list. Default: 20.
        cutoff (float32): the cutoff distance for short-range force calculation. Default: 10.0.
        skin (float32): the maximum value of the distance atom allowed to move. Default: 2.0.
        max_atom_in_grid_numbers (int32): the maximum number of atoms in one grid k. Default: 64.
        max_neighbor_numbers (int32): the maximum number of neighbors m. Default: 800.
        forced_update (int32): the flag that decides whether to force an update. Default: 0.
        forced_check (int32): the flag that decides whether to force a check. Default: 0.

    Inputs:
        - **atom_numbers_in_grid_bucket** (Tensor) - The number of atoms in each grid bucket.
          The data type is int32 and the shape is :math:`(G,)`.
        - **bucket** (Tensor) - The atom indices in each grid bucket.
          The data type is int32 and the shape is :math:`(G, k)`.
        - **crd** (Tensor) - The coordinates of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **box_length** (Tensor) - The box length of the simulation box.
          The data type is float32 and the shape is :math:`(3,)`.
        - **grid_n** (Tensor) - The number of grids divided of 3 dimensions of the simulation box.
          The data type is int32 and the shape is :math:`(3,)`.
        - **grid_length_inverse** (Tensor) - The inverse value of grid length.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_in_grid_serial** (Tensor) - The grid index for each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **old_crd** (Tensor) - The coordinates before update of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **crd_to_uint_crd_cof** (Tensor) - The scale factor between the unsigned int coordinate and the real one.
          The data type is float32 and the shape is :math:`(3,)`.
        - **uint_crd** (Tensor) - The unsigned int coordinates value of each atom.
          The data type is unsigned int32 and the shape is :math:`(n, 3)`.
        - **gpointer** (Tensor) - The nearest neighbor grids (including self) of each grid.
          The data type is int32 and the shape is :math:`(G, 125)`.
        - **nl_atom_numbers** (Tensor) - The number of atoms in neighbor list of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **nl_atom_serial** (Tensor) - The indices of atoms in neighbor list of each atom.
          The data type is int32 and the shape is :math:`(n, m)`.
        - **uint_dr_to_dr_cof** (Tensor) - The scale factor.
          The data type is float32 and the shape is :math:`(3,)`.
        - **excluded_list_start** (Tensor) - The start excluded index in excluded list for each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **excluded_list** (Tensor) - The contiguous join of excluded list of each atom.
          The data type is int32 and the shape is :math:`(E,)`.
        - **excluded_numbers** (Tensor) - The number of atom excluded in excluded list for each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **need_refresh_flag** (Tensor) - Whether the neighbor list of each atom need update or not.
          The data type is int32 and the shape is :math:`(1,)`.
        - **refresh_count** (Union[Tensor, Scalar]) - Count how many iteration steps have passed since last update.
          The data type is int32 and the shape is :math:`(1,)` or :math:`()`.

    Outputs:
        - **res** (Tensor) - The return value after updating successfully.
          The data type is float32 and the shape is :math:`(1,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, grid_numbers, atom_numbers, not_first_time, nxy, excluded_atom_numbers,
                 cutoff_square, half_skin_square, cutoff_with_skin, half_cutoff_with_skin, cutoff_with_skin_square,
                 refresh_interval=20, cutoff=10.0, skin=2.0, max_atom_in_grid_numbers=64, max_neighbor_numbers=800,
                 forced_update=0, forced_check=0):
        """Initialize NeighborListRefresh: store, type-check and register every argument."""
        self.grid_numbers = grid_numbers
        self.atom_numbers = atom_numbers
        self.refresh_interval = refresh_interval
        self.not_first_time = not_first_time
        self.cutoff = cutoff
        self.skin = skin
        self.max_atom_in_grid_numbers = max_atom_in_grid_numbers
        self.nxy = nxy
        self.excluded_atom_numbers = excluded_atom_numbers
        self.cutoff_square = cutoff_square
        self.half_skin_square = half_skin_square
        self.cutoff_with_skin = cutoff_with_skin
        self.half_cutoff_with_skin = half_cutoff_with_skin
        self.cutoff_with_skin_square = cutoff_with_skin_square
        self.max_neighbor_numbers = max_neighbor_numbers
        self.forced_update = forced_update
        self.forced_check = forced_check
        validator.check_value_type('grid_numbers', grid_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('refresh_interval', refresh_interval, int, self.name)
        validator.check_value_type('not_first_time', not_first_time, int, self.name)
        validator.check_value_type('cutoff', cutoff, float, self.name)
        validator.check_value_type('skin', skin, float, self.name)
        validator.check_value_type('max_atom_in_grid_numbers', max_atom_in_grid_numbers, int, self.name)
        validator.check_value_type('nxy', nxy, int, self.name)
        validator.check_value_type('excluded_atom_numbers', excluded_atom_numbers, int, self.name)
        validator.check_value_type('cutoff_square', cutoff_square, float, self.name)
        validator.check_value_type('half_skin_square', half_skin_square, float, self.name)
        validator.check_value_type('cutoff_with_skin', cutoff_with_skin, float, self.name)
        validator.check_value_type('half_cutoff_with_skin', half_cutoff_with_skin, float, self.name)
        validator.check_value_type('cutoff_with_skin_square', cutoff_with_skin_square, float, self.name)
        validator.check_value_type('max_neighbor_numbers', max_neighbor_numbers, int, self.name)
        validator.check_value_type('forced_update', forced_update, int, self.name)
        validator.check_value_type('forced_check', forced_check, int, self.name)
        self.init_prim_io_names(
            inputs=['atom_numbers_in_grid_bucket', 'bucket', 'crd', 'box_length', 'grid_n', 'grid_length_inverse',
                    'atom_in_grid_serial', 'old_crd', 'crd_to_uint_crd_cof', 'uint_crd', 'gpointer', 'nl_atom_numbers',
                    'nl_atom_serial', 'uint_dr_to_dr_cof', 'excluded_list_start', 'excluded_list', 'excluded_numbers',
                    'need_refresh_flag', 'refresh_count'], outputs=['res'])

        self.add_prim_attr('grid_numbers', self.grid_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.add_prim_attr('refresh_interval', self.refresh_interval)
        self.add_prim_attr('not_first_time', self.not_first_time)
        self.add_prim_attr('cutoff', self.cutoff)
        self.add_prim_attr('skin', self.skin)
        self.add_prim_attr('max_atom_in_grid_numbers', self.max_atom_in_grid_numbers)
        self.add_prim_attr('nxy', self.nxy)
        self.add_prim_attr('excluded_atom_numbers', self.excluded_atom_numbers)
        self.add_prim_attr('cutoff_square', self.cutoff_square)
        self.add_prim_attr('half_skin_square', self.half_skin_square)
        self.add_prim_attr('cutoff_with_skin', self.cutoff_with_skin)
        self.add_prim_attr('half_cutoff_with_skin', self.half_cutoff_with_skin)
        self.add_prim_attr('cutoff_with_skin_square', self.cutoff_with_skin_square)
        # Registered explicitly like every other argument; it was previously the only
        # validated argument without an explicit add_prim_attr call.
        self.add_prim_attr('max_neighbor_numbers', self.max_neighbor_numbers)
        self.add_prim_attr('forced_update', self.forced_update)
        self.add_prim_attr('forced_check', self.forced_check)
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, atom_numbers_in_grid_bucket_shape, bucket_shape, crd_shape, box_length_shape, grid_n_shape,
                    grid_length_inverse_shape, atom_in_grid_serial_shape, old_crd_shape, crd_to_uint_crd_cof_shape,
                    uint_crd_shape, gpointer_shape, nl_atom_numbers_shape, nl_atom_serial_shape,
                    uint_dr_to_dr_cof_shape, excluded_list_start_shape, excluded_list_shape, excluded_numbers_shape,
                    need_refresh_flag_shape, refresh_count_shape):
        """Validate the rank and extents of all 19 inputs and return the output shape ``[1]``.

        ``refresh_count`` may be rank 0 or rank 1; when rank 1 its single extent must be 1.
        """
        # Rank checks first, so the indexed extent checks below are safe.
        validator.check_int(len(atom_numbers_in_grid_bucket_shape), 1, Rel.EQ,
                            "atom_numbers_in_grid_bucket_dim", self.name)
        validator.check_int(len(bucket_shape), 2, Rel.EQ, "bucket_dim", self.name)
        validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", self.name)
        validator.check_int(len(box_length_shape), 1, Rel.EQ, "box_length_dim", self.name)
        validator.check_int(len(grid_n_shape), 1, Rel.EQ, "grid_n_dim", self.name)
        validator.check_int(len(grid_length_inverse_shape), 1, Rel.EQ, "grid_length_inverse_dim", self.name)
        validator.check_int(len(atom_in_grid_serial_shape), 1, Rel.EQ, "atom_in_grid_serial_dim", self.name)
        validator.check_int(len(old_crd_shape), 2, Rel.EQ, "old_crd_dim", self.name)
        validator.check_int(len(crd_to_uint_crd_cof_shape), 1, Rel.EQ, "crd_to_uint_crd_cof_dim", self.name)
        validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", self.name)
        validator.check_int(len(gpointer_shape), 2, Rel.EQ, "gpointer_dim", self.name)
        validator.check_int(len(nl_atom_numbers_shape), 1, Rel.EQ, "nl_atom_numbers_dim", self.name)
        validator.check_int(len(nl_atom_serial_shape), 2, Rel.EQ, "nl_atom_serial_dim", self.name)
        validator.check_int(len(uint_dr_to_dr_cof_shape), 1, Rel.EQ, "uint_dr_to_dr_cof_dim", self.name)
        validator.check_int(len(excluded_list_start_shape), 1, Rel.EQ, "excluded_list_start_dim", self.name)
        validator.check_int(len(excluded_list_shape), 1, Rel.EQ, "excluded_list_dim", self.name)
        validator.check_int(len(excluded_numbers_shape), 1, Rel.EQ, "excluded_numbers_dim", self.name)
        validator.check_int(len(need_refresh_flag_shape), 1, Rel.EQ, "need_refresh_flag_dim", self.name)
        # refresh_count is allowed to be a scalar (rank 0) or a rank-1 tensor.
        validator.check_int(len(refresh_count_shape), 1, Rel.LE, "refresh_count_dim", self.name)
        validator.check_int(atom_numbers_in_grid_bucket_shape[0], self.grid_numbers, Rel.EQ,
                            "atom_numbers_in_grid_bucket", self.name)
        validator.check_int(bucket_shape[0], self.grid_numbers, Rel.EQ, "bucket", self.name)
        validator.check_int(bucket_shape[1], self.max_atom_in_grid_numbers, Rel.EQ, "bucket", self.name)
        validator.check_int(crd_shape[0], self.atom_numbers, Rel.EQ, "crd", self.name)
        validator.check_int(crd_shape[1], 3, Rel.EQ, "crd", self.name)
        validator.check_int(box_length_shape[0], 3, Rel.EQ, "box_length", self.name)
        validator.check_int(grid_n_shape[0], 3, Rel.EQ, "grid_n", self.name)
        validator.check_int(grid_length_inverse_shape[0], 3, Rel.EQ, "grid_length_inverse", self.name)
        validator.check_int(atom_in_grid_serial_shape[0], self.atom_numbers, Rel.EQ, "atom_in_grid_serial",
                            self.name)
        validator.check_int(old_crd_shape[0], self.atom_numbers, Rel.EQ, "old_crd", self.name)
        validator.check_int(old_crd_shape[1], 3, Rel.EQ, "old_crd", self.name)
        validator.check_int(crd_to_uint_crd_cof_shape[0], 3, Rel.EQ, "crd_to_uint_crd_cof", self.name)
        validator.check_int(uint_crd_shape[0], self.atom_numbers, Rel.EQ, "uint_crd", self.name)
        validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd", self.name)
        validator.check_int(gpointer_shape[0], self.grid_numbers, Rel.EQ, "gpointer", self.name)
        validator.check_int(gpointer_shape[1], 125, Rel.EQ, "gpointer", self.name)
        validator.check_int(nl_atom_numbers_shape[0], self.atom_numbers, Rel.EQ, "nl_atom_numbers", self.name)
        validator.check_int(nl_atom_serial_shape[0], self.atom_numbers, Rel.EQ, "nl_atom_serial", self.name)
        validator.check_int(nl_atom_serial_shape[1], self.max_neighbor_numbers, Rel.EQ, "nl_atom_serial",
                            self.name)
        validator.check_int(uint_dr_to_dr_cof_shape[0], 3, Rel.EQ, "uint_dr_to_dr_cof", self.name)
        validator.check_int(excluded_list_start_shape[0], self.atom_numbers, Rel.EQ, "excluded_list_start",
                            self.name)
        validator.check_int(excluded_list_shape[0], self.excluded_atom_numbers, Rel.EQ, "excluded_list",
                            self.name)
        validator.check_int(excluded_numbers_shape[0], self.atom_numbers, Rel.EQ, "excluded_numbers", self.name)
        validator.check_int(need_refresh_flag_shape[0], 1, Rel.EQ, "need_refresh_flag", self.name)
        if refresh_count_shape:
            validator.check_int(refresh_count_shape[0], 1, Rel.EQ, "refresh_count_shape", self.name)
        return [1,]

    def infer_dtype(self, atom_numbers_in_grid_bucket_dtype, bucket_dtype, crd_dtype, box_length_dtype, grid_n_dtype,
                    grid_length_inverse_dtype, atom_in_grid_serial_dtype, old_crd_dtype, crd_to_uint_crd_cof_dtype,
                    uint_crd_dtype, gpointer_dtype, nl_atom_numbers_dtype, nl_atom_serial_dtype,
                    uint_dr_to_dr_cof_dtype, excluded_list_start_dtype, excluded_list_dtype, excluded_numbers_dtype,
                    need_refresh_flag_dtype, refresh_count_dtype):
        """Check the dtype of all 19 inputs and return the output dtype ``mstype.float32``.

        Index/count tensors must be int32, coordinates and scale factors float32,
        and ``uint_crd`` uint32.
        """
        validator.check_tensor_dtype_valid('atom_numbers_in_grid_bucket', atom_numbers_in_grid_bucket_dtype,
                                           [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('bucket', bucket_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('box_length', box_length_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('grid_n', grid_n_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('grid_length_inverse', grid_length_inverse_dtype, [mstype.float32],
                                           self.name)
        validator.check_tensor_dtype_valid('atom_in_grid_serial', atom_in_grid_serial_dtype, [mstype.int32],
                                           self.name)
        validator.check_tensor_dtype_valid('old_crd', old_crd_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('crd_to_uint_crd_cof', crd_to_uint_crd_cof_dtype, [mstype.float32],
                                           self.name)
        validator.check_tensor_dtype_valid('uint_crd', uint_crd_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('gpointer', gpointer_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('nl_atom_numbers', nl_atom_numbers_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('nl_atom_serial', nl_atom_serial_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('uint_dr_to_dr_cof', uint_dr_to_dr_cof_dtype, [mstype.float32],
                                           self.name)
        validator.check_tensor_dtype_valid('excluded_list_start', excluded_list_start_dtype, [mstype.int32],
                                           self.name)
        validator.check_tensor_dtype_valid('excluded_list', excluded_list_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('excluded_numbers', excluded_numbers_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('need_refresh_flag', need_refresh_flag_dtype, [mstype.int32],
                                           self.name)
        validator.check_tensor_dtype_valid('refresh_count', refresh_count_dtype, [mstype.int32],
                                           self.name)

        return mstype.float32
1044
-
1045
-
1046
class MDIterationLeapFrog(PrimitiveWithInfer):
    """
    One step of classical leap frog algorithm to solve the finite difference
    Hamiltonian equations of motion for certain system.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        atom_numbers (int32): the number of atoms n.
        dt (float32): the simulation time step.

    Inputs:
        - **vel** (Tensor) - The velocity of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **crd** (Tensor) - The coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **frc** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **acc** (Tensor) - The acceleration of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **inverse_mass** (Tensor) - The inverse value of mass of each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Outputs:
        - **res** (Tensor) - The return value after updating successfully.
          The data type is float32 and the shape is :math:`(1,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, dt):
        """Initialize MDIterationLeapFrog"""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('dt', dt, float, self.name)
        self.atom_numbers = atom_numbers
        self.dt = dt
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.add_prim_attr('dt', self.dt)
        # The declared input names must correspond one-to-one with the five tensor
        # parameters of infer_shape/infer_dtype below (same as MDIterationLeapFrogWithMaxVel).
        self.init_prim_io_names(
            inputs=['vel', 'crd', 'frc', 'acc', 'inverse_mass'],
            outputs=['res'])
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, vel, crd, frc, acc, inverse_mass):
        """Validate the input shapes and return the output shape ``[1]``.

        vel, crd, frc and acc must be (n, 3); inverse_mass must be (n,), with
        n = self.atom_numbers.
        """
        n = self.atom_numbers
        validator.check_int(len(vel), 2, Rel.EQ, "vel_dim", self.name)
        validator.check_int(len(crd), 2, Rel.EQ, "crd_dim", self.name)
        validator.check_int(len(frc), 2, Rel.EQ, "frc_dim", self.name)
        validator.check_int(len(acc), 2, Rel.EQ, "acc_dim", self.name)
        validator.check_int(len(inverse_mass), 1, Rel.EQ, "inverse_mass_dim", self.name)
        validator.check_int(vel[0], n, Rel.EQ, "vel_shape[0]", self.name)
        validator.check_int(vel[1], 3, Rel.EQ, "vel_shape[1]", self.name)
        validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", self.name)
        validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", self.name)
        validator.check_int(frc[0], n, Rel.EQ, "frc_shape[0]", self.name)
        validator.check_int(frc[1], 3, Rel.EQ, "frc_shape[1]", self.name)
        validator.check_int(acc[0], n, Rel.EQ, "acc_shape[0]", self.name)
        validator.check_int(acc[1], 3, Rel.EQ, "acc_shape[1]", self.name)
        validator.check_int(inverse_mass[0], n, Rel.EQ, "inverse_mass_shape", self.name)
        return [1,]

    def infer_dtype(self, vel, crd, frc, acc, inverse_mass):
        """Require float32 for every input and return ``mstype.float32``."""
        validator.check_tensor_dtype_valid('inverse_mass', inverse_mass, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('vel', vel, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('crd', crd, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('frc', frc, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('acc', acc, [mstype.float32], self.name)
        return mstype.float32
1119
-
1120
-
1121
class MDIterationLeapFrogWithMaxVel(PrimitiveWithInfer):
    """
    Leap frog algorithm to solve the Hamiltonian equations of motion with a maximum velocity limit.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        atom_numbers (int32): the number of atoms n.
        dt (float32): the simulation time step.
        max_velocity (float32): the maximum velocity limit.

    Inputs:
        - **vel** (Tensor) - The velocity of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **crd** (Tensor) - The coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **frc** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **acc** (Tensor) - The acceleration of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **inverse_mass** (Tensor) - The inverse value of mass of each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Outputs:
        - **res** (Tensor) - The return value after updating successfully.
          The data type is float32 and the shape is :math:`(1,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, dt, max_velocity):
        """Initialize MDIterationLeapFrogWithMaxVel"""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('dt', dt, float, self.name)
        validator.check_value_type('max_velocity', max_velocity, float, self.name)
        self.atom_numbers = atom_numbers
        self.dt = dt
        self.max_velocity = max_velocity

        # Register each stored hyper-parameter as a primitive attribute, in declaration order.
        for attr_name in ('atom_numbers', 'dt', 'max_velocity'):
            self.add_prim_attr(attr_name, getattr(self, attr_name))
        self.init_prim_io_names(inputs=['vel', 'crd', 'frc', 'acc', 'inverse_mass'], outputs=['res'])
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, vel, crd, frc, acc, inverse_mass):
        """Validate the input shapes and return the output shape ``[1]``.

        vel, crd, frc and acc must be (n, 3); inverse_mass must be (n,), with
        n = self.atom_numbers.
        """
        atom_count = self.atom_numbers
        # Ranks first, in the same order as the inputs are declared.
        rank_rules = (
            (vel, 2, "vel_dim"),
            (crd, 2, "crd_dim"),
            (frc, 2, "frc_dim"),
            (acc, 2, "acc_dim"),
            (inverse_mass, 1, "inverse_mass_dim"),
        )
        for shape, rank, label in rank_rules:
            validator.check_int(len(shape), rank, Rel.EQ, label, self.name)
        validator.check_int(inverse_mass[0], atom_count, Rel.EQ, "inverse_mass_shape", self.name)
        # Each (n, 3) tensor: rows must equal the atom count, columns must be 3.
        matrix_rules = (
            (vel, "vel_shape[0]", "vel_shape[1]"),
            (crd, "crd_shape[0]", "crd_shape[1]"),
            (frc, "frc_shape[0]", "frc_shape[1]"),
            (acc, "acc_shape[0]", "acc_shape[1]"),
        )
        for shape, rows_label, cols_label in matrix_rules:
            validator.check_int(shape[0], atom_count, Rel.EQ, rows_label, self.name)
            validator.check_int(shape[1], 3, Rel.EQ, cols_label, self.name)
        return [1]

    def infer_dtype(self, vel, crd, frc, acc, inverse_mass):
        """Require float32 for every input and return ``mstype.float32``."""
        for arg_name, arg_dtype in (('inverse_mass', inverse_mass), ('vel', vel), ('crd', crd),
                                    ('frc', frc), ('acc', acc)):
            validator.check_tensor_dtype_valid(arg_name, arg_dtype, [mstype.float32], self.name)
        return mstype.float32
1197
-
1198
-
1199
class MDIterationGradientDescent(PrimitiveWithInfer):
    """
    Update the coordinate of each atom in the direction of potential for energy minimization.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        atom_numbers (int32): the number of atoms n.
        learning_rate (float32): the update step length.

    Inputs:
        - **crd** (Tensor) - The coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **frc** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Outputs:
        - **res** (Tensor) - The return value after updating successfully.
          The data type is float32 and the shape is :math:`(1,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, learning_rate):
        """Initialize MDIterationGradientDescent"""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('learning_rate', learning_rate, float, self.name)
        self.atom_numbers = atom_numbers
        self.learning_rate = learning_rate
        for attr_name in ('atom_numbers', 'learning_rate'):
            self.add_prim_attr(attr_name, getattr(self, attr_name))
        self.init_prim_io_names(inputs=['crd', 'frc'], outputs=['res'])
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, crd, frc):
        """Require crd and frc to both be (n, 3) with n = self.atom_numbers; return ``[1]``."""
        atom_count = self.atom_numbers
        validator.check_int(len(crd), 2, Rel.EQ, "crd_dim", self.name)
        validator.check_int(len(frc), 2, Rel.EQ, "frc_dim", self.name)
        for shape, rows_label, cols_label in ((crd, "crd_shape[0]", "crd_shape[1]"),
                                              (frc, "frc_shape[0]", "frc_shape[1]")):
            validator.check_int(shape[0], atom_count, Rel.EQ, rows_label, self.name)
            validator.check_int(shape[1], 3, Rel.EQ, cols_label, self.name)
        return [1]

    def infer_dtype(self, crd, frc):
        """Require float32 for both inputs and return ``mstype.float32``."""
        for arg_name, arg_dtype in (('crd', crd), ('frc', frc)):
            validator.check_tensor_dtype_valid(arg_name, arg_dtype, [mstype.float32], self.name)
        return mstype.float32
1252
-
1253
-
1254
class BondForceWithAtomEnergyAndVirial(PrimitiveWithInfer):
    """
    Calculate bond force, harmonic potential energy and atom virial together.

    The calculation formula is the same as operator BondForce() and BondEnergy().

    Because there are many inputs and they are interdependent,
    there is no way to construct `Examples` using random methods. For details, refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        bond_numbers (int32): the number of harmonic bonds m.
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor (x, y, z).
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The first atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The second atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **bond_k** (Tensor) - The force constant of each bond.
          The data type is float32 and the shape is :math:`(m,)`.
        - **bond_r0** (Tensor) - The equilibrium length of each bond.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **frc_f** (Tensor) - The force of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **atom_e** (Tensor) - The energy of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **atom_virial** (Tensor) - The virial of each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, bond_numbers, atom_numbers):
        """Initialize BondForceWithAtomEnergyAndVirial"""
        validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.bond_numbers = bond_numbers
        self.atom_numbers = atom_numbers
        self.add_prim_attr('bond_numbers', self.bond_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'],
                                outputs=['frc_f', 'atom_e', 'atom_virial'])

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape):
        """Validate input shapes and return the three output shapes.

        Returns:
            tuple: ``(uint_crd_f_shape, [n], [n])`` for frc_f, atom_e and atom_virial,
            with n = self.atom_numbers.
        """
        cls_name = self.name
        n = self.atom_numbers
        m = self.bond_numbers
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
        validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
        validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
        validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        validator.check_int(atom_a_shape[0], m, Rel.EQ, "atom_a_shape", cls_name)
        validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
        validator.check_int(bond_k_shape[0], m, Rel.EQ, "bond_k_shape", cls_name)
        validator.check_int(bond_r0_shape[0], m, Rel.EQ, "bond_r0_shape", cls_name)

        # The per-atom force has the same (n, 3) shape as the coordinates it is computed from.
        return uint_crd_f_shape, [n,], [n,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
        """Check input dtypes (uint32 coords, int32 indices, float32 rest); all three outputs are float32."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)

        validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)

        validator.check_tensor_dtype_valid('bond_k', bond_k_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('bond_r0', bond_r0_type, [mstype.float32], self.name)
        return mstype.float32, mstype.float32, mstype.float32
1339
-
1340
-
1341
- class LJForceWithVirialEnergy(PrimitiveWithInfer):
1342
- """
1343
- Calculate the Lennard-Jones force, virial and atom energy together.
1344
-
1345
- The calculation formula of Lennard-Jones part is the same as operator
1346
- LJForce(), and the PME direct part is within PME method.
1347
-
1348
- Because there are many inputs and they are all interrelated,
1349
- there is no way to construct `Examples` using random methods. For details, refer to the webpage `SPONGE in MindSpore
1350
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
1351
-
1352
- .. warning::
1353
- This is an experimental prototype that is subject to change and/or deletion.
1354
-
1355
- Args:
1356
- atom_numbers (int32): the number of atoms, n.
1357
- cutoff (float32): the square value of cutoff.
1358
- pme_beta (float32): PME beta parameter, same as operator PMEReciprocalForce().
1359
- max_neighbor_numbers (int32): the max neighbor numbers, default 800.
1360
-
1361
- Inputs:
1362
- - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
1363
- The data type is uint32 and the shape is :math:`(n, 3)`.
1364
- - **LJtype** (Tensor) - The Lennard-Jones type of each atom.
1365
- The data type is int32 and the shape is :math:`(n,)`.
1366
- - **charge** (Tensor) - The charge carried by each atom.
1367
- The data type is float32 and the shape is :math:`(n,)`.
1368
- - **scaler** (Tensor) - The scale factor between real
1369
- space coordinate and its unsigned int value.
1370
- The data type is float32 and the shape is :math:`(3,)`.
1371
- - **nl_numbers** (Tensor) - The number of neighbors of each atom.
1372
- The data type is int32 and the shape is :math:`(n,)`.
1373
- - **nl_serial** (Tensor) - The neighbor list of each atom, the max number is 800.
1374
- The data type is int32 and the shape is :math:`(n, 800)`.
1375
- - **d_LJ_A** (Tensor) - The Lennard-Jones A coefficient of each kind of atom pair.
1376
- The number of atom pair is q. The data type is float32 and the shape is :math:`(q,)`.
1377
- - **d_LJ_B** (Tensor) - The Lennard-Jones B coefficient of each kind of atom pair.
1378
- The number of atom pair is q. The data type is float32 and the shape is :math:`(q,)`.
1379
-
1380
- Outputs:
1381
- - **frc** (Tensor) - The force felt by each atom.
1382
- The data type is float32 and the shape is :math:`(n, 3)`.
1383
- - **virial** (Tensor) - The virial felt by each atom.
1384
- The data type is float32 and the shape is :math:`(n,)`.
1385
- - **atom_energy** (Tensor) - The energy of each atom.
1386
- The data type is float32 and the shape is :math:`(n,)`.
1387
-
1388
- Supported Platforms:
1389
- ``GPU``
1390
- """
1391
-
1392
- @prim_attr_register
1393
- def __init__(self, atom_numbers, cutoff, pme_beta, max_neighbor_numbers=800):
1394
- """Initialize LJForceWithVirialEnergy"""
1395
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
1396
- validator.check_value_type('cutoff', cutoff, float, self.name)
1397
- validator.check_value_type('pme_beta', pme_beta, float, self.name)
1398
- validator.check_value_type('max_neighbor_numbers', max_neighbor_numbers, int, self.name)
1399
- self.atom_numbers = atom_numbers
1400
- self.cutoff = cutoff
1401
- self.pme_beta = pme_beta
1402
- self.max_neighbor_numbers = max_neighbor_numbers
1403
- self.init_prim_io_names(
1404
- inputs=['uint_crd', 'LJtype', 'charge', 'scaler', 'nl_numbers', 'nl_serial', 'd_LJ_A', 'd_LJ_B'],
1405
- outputs=['frc', 'virial', 'atom_energy'])
1406
- self.add_prim_attr('atom_numbers', self.atom_numbers)
1407
- self.add_prim_attr('cutoff', self.cutoff)
1408
- self.add_prim_attr('pme_beta', self.pme_beta)
1409
- self.add_prim_attr('max_neighbor_numbers', self.max_neighbor_numbers)
1410
-
1411
- def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
1412
- cls_name = self.name
1413
- n = self.atom_numbers
1414
- q = d_lj_a[0]
1415
- m = self.max_neighbor_numbers
1416
- validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
1417
- validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
1418
- validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
1419
- validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
1420
- validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
1421
- validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
1422
- validator.check_int(len(d_lj_a), 1, Rel.EQ, "d_LJ_A_dim", cls_name)
1423
- validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
1424
-
1425
- validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
1426
- validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
1427
- validator.check_int(ljtype[0], n, Rel.EQ, "LJtype_shape", cls_name)
1428
- validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
1429
- validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
1430
- validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
1431
- validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
1432
- validator.check_int(nl_serial[1], m, Rel.EQ, "nl_serial_shape[1]", cls_name)
1433
- validator.check_int(d_lj_a[0], q, Rel.EQ, "d_LJ_A_shape[0]", cls_name)
1434
- validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
1435
- return [n, 3], [n,], [n,]
1436
-
1437
- def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
1438
- validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
1439
- validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
1440
- validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
1441
- validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
1442
- validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
1443
- validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
1444
- validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
1445
- validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
1446
- return mstype.float32, mstype.float32, mstype.float32
1447
-
1448
-
1449
- class LJForceWithPMEDirectForceUpdate(PrimitiveWithInfer):
1450
- """
1451
- Calculate the Lennard-Jones force and PME direct force together for pressure.
1452
-
1453
- The calculation formula of Lennard-Jones part is the same as operator
1454
- LJForce(), and the PME direct part is within PME method.
1455
-
1456
- Because there are many inputs and they are all interrelated,
1457
- there is no way to construct `Examples` using random methods. For details, refer to the webpage `SPONGE in MindSpore
1458
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
1459
-
1460
- .. warning::
1461
- This is an experimental prototype that is subject to change and/or deletion.
1462
-
1463
- Args:
1464
- atom_numbers (int32): the number of atoms, n.
1465
- cutoff (float32): the square value of cutoff.
1466
- pme_beta (float32): PME beta parameter, same as operator PMEReciprocalForce().
1467
- need_update (int32): if need_update = 1, calculate the pressure, default 0.
1468
-
1469
- Inputs:
1470
- - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
1471
- The data type is uint32 and the shape is :math:`(n, 3)`.
1472
- - **LJtype** (Tensor) - The Lennard-Jones type of each atom.
1473
- The data type is int32 and the shape is :math:`(n,)`.
1474
- - **charge** (Tensor) - The charge carried by each atom.
1475
- The data type is float32 and the shape is :math:`(n,)`.
1476
- - **scaler** (Tensor) - The scale factor between real
1477
- space coordinate and its unsigned int value.
1478
- The data type is float32 and the shape is :math:`(3,)`.
1479
- - **nl_numbers** (Tensor) - The number of neighbors of each atom.
1480
- The data type is int32 and the shape is :math:`(n,)`.
1481
- - **nl_serial** (Tensor) - The neighbor list of each atom, the max number is 800.
1482
- The data type is int32 and the shape is :math:`(n, 800)`.
1483
- - **d_LJ_A** (Tensor) - The Lennard-Jones A coefficient of each kind of atom pair.
1484
- The number of atom pair is q. The data type is float32 and the shape is :math:`(q,)`.
1485
- - **d_LJ_B** (Tensor) - The Lennard-Jones B coefficient of each kind of atom pair.
1486
- The number of atom pair is q. The data type is float32 and the shape is :math:`(q,)`.
1487
- - **beta** (Tensor) - PME beta parameter. The data type is float32 and the shape is :math:`(1,)`.
1488
-
1489
- Outputs:
1490
- - **frc** (Tensor) - The force felt by each atom.
1491
- The data type is float32 and the shape is :math:`(n, 3)`.
1492
-
1493
- Supported Platforms:
1494
- ``GPU``
1495
- """
1496
-
1497
- @prim_attr_register
1498
- def __init__(self, atom_numbers, cutoff, pme_beta, need_update=0):
1499
- """Initialize LJForceWithPMEDirectForceUpdate"""
1500
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
1501
- validator.check_value_type('cutoff', cutoff, float, self.name)
1502
- validator.check_value_type('pme_beta', pme_beta, float, self.name)
1503
- validator.check_value_type('need_update', need_update, int, self.name)
1504
- self.atom_numbers = atom_numbers
1505
- self.cutoff = cutoff
1506
- self.pme_beta = pme_beta
1507
- self.need_update = need_update
1508
- self.init_prim_io_names(
1509
- inputs=['uint_crd', 'LJtype', 'charge', 'scaler', 'nl_numbers', 'nl_serial', 'd_LJ_A', 'd_LJ_B', 'beta'],
1510
- outputs=['frc'])
1511
- self.add_prim_attr('atom_numbers', self.atom_numbers)
1512
- self.add_prim_attr('cutoff', self.cutoff)
1513
- self.add_prim_attr('pme_beta', self.pme_beta)
1514
- self.add_prim_attr('need_update', self.need_update)
1515
-
1516
- def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b, beta):
1517
- cls_name = self.name
1518
- n = self.atom_numbers
1519
- q = d_lj_a[0]
1520
- m = nl_serial[1]
1521
- validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
1522
- validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
1523
- validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
1524
- validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
1525
- validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
1526
- validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
1527
- validator.check_int(len(d_lj_a), 1, Rel.EQ, "d_LJ_A_dim", cls_name)
1528
- validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
1529
- validator.check_int(len(beta), 1, Rel.EQ, "beta_dim", cls_name)
1530
-
1531
- validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
1532
- validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
1533
- validator.check_int(ljtype[0], n, Rel.EQ, "LJtype_shape", cls_name)
1534
- validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
1535
- validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
1536
- validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
1537
- validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
1538
- validator.check_int(nl_serial[1], m, Rel.EQ, "nl_serial_shape[1]", cls_name)
1539
- validator.check_int(d_lj_a[0], q, Rel.EQ, "d_LJ_A_shape[0]", cls_name)
1540
- validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
1541
- validator.check_int(beta[0], 1, Rel.EQ, "beta_shape", cls_name)
1542
- return [n, 3]
1543
-
1544
- def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b, beta):
1545
- validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
1546
- validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
1547
- validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
1548
- validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
1549
- validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
1550
- validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
1551
- validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
1552
- validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
1553
- validator.check_tensor_dtype_valid('beta', beta, [mstype.float32], self.name)
1554
- return mstype.float32
1555
-
1556
-
1557
- class PMEReciprocalForceUpdate(PrimitiveWithInfer):
1558
- """
1559
- Calculate the reciprocal part of long-range Coulomb force using
1560
- PME (Particle Mesh Ewald) method for pressure. Assume the number of atoms is n.
1561
-
1562
- The detailed calculation formula of PME (Particle Mesh Ewald) method
1563
- can be found in this paper: A Smooth Particle Mesh Ewald Method. DOI:
1564
- 10.1063/1.470117.
1565
-
1566
- Because there are many inputs and they are all interrelated,
1568
- there is no way to construct `Examples` using random methods. For details, refer to the webpage `SPONGE in MindSpore
1568
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
1569
-
1570
- .. warning::
1571
- This is an experimental prototype that is subject to change and/or deletion.
1572
-
1573
- Args:
1574
- atom_numbers (int32): the number of atoms, n.
1575
- beta (float32): the PME beta parameter, determined by the
1576
- non-bond cutoff value and simulation precision tolerance.
1577
- fftx (int32): the number of points for Fourier transform in dimension X.
1578
- ffty (int32): the number of points for Fourier transform in dimension Y.
1579
- fftz (int32): the number of points for Fourier transform in dimension Z.
1580
- box_length_0 (float32): the value of boxlength idx 0
1581
- box_length_1 (float32): the value of boxlength idx 1
1582
- box_length_2 (float32): the value of boxlength idx 2
1583
- need_update (int32): if need_update = 1, calculate the pressure, default 0.
1584
-
1585
- Inputs:
1586
- - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
1587
- The data type is uint32 and the shape is :math:`(n, 3)`.
1588
- - **charge** (Tensor) - The charge carried by each atom.
1589
- The data type is float32 and the shape is :math:`(n,)`
1590
- - **beta** (Tensor) - The PME beta parameter to be updated in pressure calculation.
1591
- The data type is float32 and the shape is :math:`(1,)`
1592
-
1593
- Outputs:
1594
- - **force** (Tensor) - The force felt by each atom.
1595
- The data type is float32 and the shape is :math:`(n, 3)`
1596
-
1597
- Supported Platforms:
1598
- ``GPU``
1599
- """
1600
-
1601
- @prim_attr_register
1602
- def __init__(self, atom_numbers, beta, fftx, ffty, fftz,
1603
- box_length_0, box_length_1, box_length_2, need_update=0):
1604
- """Initialize PMEReciprocalForceUpdate"""
1605
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
1606
- validator.check_value_type('beta', beta, float, self.name)
1607
- validator.check_value_type('fftx', fftx, int, self.name)
1608
- validator.check_value_type('ffty', ffty, int, self.name)
1609
- validator.check_value_type('fftz', fftz, int, self.name)
1610
- validator.check_value_type('box_length_0', box_length_0, float, self.name)
1611
- validator.check_value_type('box_length_1', box_length_1, float, self.name)
1612
- validator.check_value_type('box_length_2', box_length_2, float, self.name)
1613
- validator.check_value_type('need_update', need_update, int, self.name)
1614
- self.atom_numbers = atom_numbers
1615
- self.beta = beta
1616
- self.fftx = fftx
1617
- self.ffty = ffty
1618
- self.fftz = fftz
1619
- self.box_length_0 = box_length_0
1620
- self.box_length_1 = box_length_1
1621
- self.box_length_2 = box_length_2
1622
- self.need_update = need_update
1623
-
1624
- self.init_prim_io_names(inputs=['uint_crd', 'charge', 'beta'],
1625
- outputs=['force'])
1626
- self.add_prim_attr('atom_numbers', self.atom_numbers)
1627
- self.add_prim_attr('beta', self.beta)
1628
- self.add_prim_attr('fftx', self.fftx)
1629
- self.add_prim_attr('ffty', self.ffty)
1630
- self.add_prim_attr('fftz', self.fftz)
1631
- self.add_prim_attr('box_length_0', self.box_length_0)
1632
- self.add_prim_attr('box_length_1', self.box_length_1)
1633
- self.add_prim_attr('box_length_2', self.box_length_2)
1634
- self.add_prim_attr('need_update', self.need_update)
1635
-
1636
- def infer_shape(self, uint_crd_shape, charge_shape, beta):
1637
- cls_name = self.name
1638
- n = self.atom_numbers
1639
- validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name)
1640
- validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
1641
- validator.check_int(len(beta), 1, Rel.EQ, "beta_dim", cls_name)
1642
- validator.check_int(uint_crd_shape[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
1643
- validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
1644
- validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
1645
- validator.check_int(beta[0], 1, Rel.EQ, "beta_shape", cls_name)
1646
- return uint_crd_shape
1647
-
1648
- def infer_dtype(self, uint_crd_type, charge_type, beta):
1649
- validator.check_tensor_dtype_valid('uint_crd', uint_crd_type, [mstype.uint32], self.name)
1650
- validator.check_tensor_dtype_valid('charge', charge_type, [mstype.float32], self.name)
1651
- validator.check_tensor_dtype_valid('beta', beta, [mstype.float32], self.name)
1652
- return charge_type
1653
-
1654
-
1655
- class PMEExcludedForceUpdate(PrimitiveWithInfer):
1656
- """
1657
- Calculate the excluded part of long-range Coulomb force using
1658
- PME (Particle Mesh Ewald) method for pressure. Assume the number of atoms is
1659
- n, and the length of excluded list is E.
1660
-
1661
- Because there are many inputs and they are all interrelated,
1663
- there is no way to construct `Examples` using random methods. For details, refer to the webpage `SPONGE in MindSpore
1663
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
1664
-
1665
- .. warning::
1666
- This is an experimental prototype that is subject to change and/or deletion.
1667
-
1668
- Args:
1669
- atom_numbers (int32): the number of atoms, n.
1670
- excluded_numbers (int32): the length of excluded list, E.
1671
- beta (float32): the PME beta parameter, determined by the
1672
- non-bond cutoff value and simulation precision tolerance.
1673
- need_update (int32): if need_update = 1, calculate the pressure, default 0.
1674
-
1675
- Inputs:
1676
- - **uint_crd** (Tensor) - The unsigned int coordinates value of each atom.
1677
- The data type is uint32 and the shape is :math:`(n, 3)`
1678
- - **scaler** (Tensor) - The scale factor between real space
1679
- coordinates and its unsigned int value. The data type is float32 and the shape is :math:`(3,)`
1680
- - **charge** (Tensor) - The charge carried by each atom.
1681
- The data type is float32 and the shape is :math:`(n,)`
1682
- - **excluded_list_start** (Tensor) - The start excluded index
1683
- in excluded list for each atom. The data type is int32 and the shape is :math:`(n,)`
1684
- - **excluded_list** (Tensor) - The contiguous join of excluded
1685
- list of each atom. E is the number of excluded atoms. The data type is int32 and the shape is :math:`(E,)`
1686
- - **excluded_atom_numbers** (Tensor) - The number of atom excluded
1687
- in excluded list for each atom. The data type is int32 and the shape is :math:`(n,)`
1688
- - **beta** (Tensor) - The PME beta parameter to be updated in pressure calculation.
1689
- The data type is float32 and the shape is :math:`(1,)`
1690
-
1691
- Outputs:
1692
- - **force** (Tensor) - The force felt by each atom.
1693
- The data type is float32 and the shape is :math:`(n, 3)`
1694
-
1695
- Supported Platforms:
1696
- ``GPU``
1697
- """
1698
-
1699
- @prim_attr_register
1700
- def __init__(self, atom_numbers, excluded_numbers, beta, need_update=0):
1701
- """Initialize PMEExcludedForceUpdate"""
1702
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
1703
- validator.check_value_type('excluded_numbers', excluded_numbers, int, self.name)
1704
- validator.check_value_type('beta', beta, float, self.name)
1705
- validator.check_value_type('need_update', need_update, int, self.name)
1706
- self.atom_numbers = atom_numbers
1707
- self.excluded_numbers = excluded_numbers
1708
- self.beta = beta
1709
- self.need_update = need_update
1710
- self.init_prim_io_names(
1711
- inputs=['uint_crd', 'scaler', 'charge', 'excluded_list_start', 'excluded_list',
1712
- 'excluded_atom_numbers', 'beta'],
1713
- outputs=['force'])
1714
- self.add_prim_attr('atom_numbers', self.atom_numbers)
1715
- self.add_prim_attr('excluded_numbers', self.excluded_numbers)
1716
- self.add_prim_attr('beta', self.beta)
1717
- self.add_prim_attr('need_update', self.need_update)
1718
-
1719
- def infer_shape(self, uint_crd_shape, scaler_shape, charge_shape, excluded_list_start_shape, excluded_list_shape,
1720
- excluded_atom_numbers_shape, beta):
1721
- cls_name = self.name
1722
- n = self.atom_numbers
1723
- e = self.excluded_numbers
1724
- validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name)
1725
- validator.check_int(len(scaler_shape), 1, Rel.EQ, "scaler_dim", cls_name)
1726
- validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
1727
- validator.check_int(len(excluded_list_start_shape), 1, Rel.EQ, "excluded_list_start_dim", cls_name)
1728
- validator.check_int(len(excluded_atom_numbers_shape), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name)
1729
- validator.check_int(len(excluded_list_shape), 1, Rel.EQ, "excluded_list_dim", cls_name)
1730
- validator.check_int(len(beta), 1, Rel.EQ, "beta_dim", cls_name)
1731
- validator.check_int(uint_crd_shape[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
1732
- validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
1733
- validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
1734
- validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
1735
- validator.check_int(excluded_list_start_shape[0], n, Rel.EQ, "excluded_list_start_shape", cls_name)
1736
- validator.check_int(excluded_atom_numbers_shape[0], n, Rel.EQ, "excluded_atom_numbers_shape", cls_name)
1737
- validator.check_int(excluded_list_shape[0], e, Rel.EQ, "excluded_list_shape", cls_name)
1738
- validator.check_int(beta[0], 1, Rel.EQ, "beta_shape", cls_name)
1739
- return [n, 3]
1740
-
1741
- def infer_dtype(self, uint_crd_type, scaler_type, charge_type, excluded_list_start_type, excluded_list_type,
1742
- excluded_atom_numbers_type, beta):
1743
- validator.check_tensor_dtype_valid('scaler', scaler_type, [mstype.float32], self.name)
1744
- validator.check_tensor_dtype_valid('uint_crd', uint_crd_type, [mstype.uint32], self.name)
1745
- validator.check_tensor_dtype_valid('charge', charge_type, [mstype.float32], self.name)
1746
- validator.check_tensor_dtype_valid('excluded_list_start', excluded_list_start_type, [mstype.int32],
1747
- self.name)
1748
- validator.check_tensor_dtype_valid('excluded_list', excluded_list_type, [mstype.int32],
1749
- self.name)
1750
- validator.check_tensor_dtype_valid('excluded_atom_numbers', excluded_atom_numbers_type, [mstype.int32],
1751
- self.name)
1752
- validator.check_tensor_dtype_valid('beta', beta, [mstype.float32], self.name)
1753
- return mstype.float32
1754
-
1755
-
1756
- class LJForceWithVirialEnergyUpdate(PrimitiveWithInfer):
1757
- """
1758
- Calculate the Lennard-Jones force, virial and atom energy together for pressure.
1759
-
1760
- The calculation formula of Lennard-Jones part is the same as operator
1761
- LJForce(), and the PME direct part is within PME method.
1762
-
1763
- Because there are many inputs and they are all interrelated,
1765
- there is no way to construct `Examples` using random methods. For details, refer to the webpage `SPONGE in MindSpore
1765
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
1766
-
1767
- .. warning::
1768
- This is an experimental prototype that is subject to change and/or deletion.
1769
-
1770
- Args:
1771
- atom_numbers (int32): the number of atoms, n.
1772
- cutoff (float32): the square value of cutoff.
1773
- pme_beta (float32): PME beta parameter, same as operator PMEReciprocalForce().
1774
- max_neighbor_numbers (int32): the max neighbor numbers, default 800.
1775
- need_update (int32): if need_update = 1, calculate the pressure, default 0.
1776
-
1777
- Inputs:
1778
- - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
1779
- The data type is uint32 and the shape is :math:`(n, 3)`.
1780
- - **LJtype** (Tensor) - The Lennard-Jones type of each atom.
1781
- The data type is int32 and the shape is :math:`(n,)`.
1782
- - **charge** (Tensor) - The charge carried by each atom.
1783
- The data type is float32 and the shape is :math:`(n,)`.
1784
- - **scaler** (Tensor) - The scale factor.
1785
- The data type is float32 and the shape is :math:`(3,)`.
1786
- - **nl_numbers** (Tensor) - The number of neighbors of each atom.
1787
- The data type is int32 and the shape is :math:`(n,)`.
1788
- - **nl_serial** (Tensor) - The neighbor list of each atom, the max number is 800.
1789
- The data type is int32 and the shape is :math:`(n, 800)`.
1790
- - **d_LJ_A** (Tensor) - The Lennard-Jones A coefficient of each kind of atom pair.
1791
- The number of atom pair is q. The data type is float32 and the shape is :math:`(q,)`.
1792
- - **d_LJ_B** (Tensor) - The Lennard-Jones B coefficient of each kind of atom pair.
1793
- The number of atom pair is q. The data type is float32 and the shape is :math:`(q,)`.
1794
- - **beta** (Tensor) - The PME beta parameter to be updated in pressure calculation.
1795
- The data type is float32 and the shape is :math:`(1,)`
1796
-
1797
- Outputs:
1798
- - **frc** (Tensor) - The force felt by each atom.
1799
- The data type is float32 and the shape is :math:`(n, 3)`.
1800
- - **virial** (Tensor) - The accumulated potential virial for each atom.
1801
- The data type is float32 and the shape is :math:`(n, )`.
1802
- - **atom_energy** (Tensor) - The accumulated potential energy for each atom.
1803
- The data type is float32 and the shape is :math:`(n, )`.
1804
-
1805
- Supported Platforms:
1806
- ``GPU``
1807
- """
1808
-
1809
- @prim_attr_register
1810
- def __init__(self, atom_numbers, cutoff, pme_beta, max_neighbor_numbers=800, need_update=0):
1811
- """Initialize LJForceWithVirialEnergyUpdate"""
1812
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
1813
- validator.check_value_type('cutoff', cutoff, float, self.name)
1814
- validator.check_value_type('pme_beta', pme_beta, float, self.name)
1815
- validator.check_value_type('max_neighbor_numbers', max_neighbor_numbers, int, self.name)
1816
- validator.check_value_type('need_update', need_update, int, self.name)
1817
- self.atom_numbers = atom_numbers
1818
- self.cutoff = cutoff
1819
- self.pme_beta = pme_beta
1820
- self.max_neighbor_numbers = max_neighbor_numbers
1821
- self.need_update = need_update
1822
- self.init_prim_io_names(
1823
- inputs=['uint_crd', 'LJtype', 'charge', 'scaler', 'nl_numbers', 'nl_serial', 'd_LJ_A', 'd_LJ_B', 'beta'],
1824
- outputs=['frc', 'virial', 'atom_energy'])
1825
- self.add_prim_attr('atom_numbers', self.atom_numbers)
1826
- self.add_prim_attr('cutoff', self.cutoff)
1827
- self.add_prim_attr('pme_beta', self.pme_beta)
1828
- self.add_prim_attr('max_neighbor_numbers', self.max_neighbor_numbers)
1829
- self.add_prim_attr('need_update', self.need_update)
1830
-
1831
- def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b, beta):
1832
- cls_name = self.name
1833
- n = self.atom_numbers
1834
- q = d_lj_a[0]
1835
- m = self.max_neighbor_numbers
1836
- validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
1837
- validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
1838
- validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
1839
- validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
1840
- validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
1841
- validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
1842
- validator.check_int(len(d_lj_a), 1, Rel.EQ, "d_LJ_A_dim", cls_name)
1843
- validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
1844
- validator.check_int(len(beta), 1, Rel.EQ, "beta_dim", cls_name)
1845
- validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
1846
- validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
1847
- validator.check_int(ljtype[0], n, Rel.EQ, "LJtype_shape", cls_name)
1848
- validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
1849
- validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
1850
- validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
1851
- validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
1852
- validator.check_int(nl_serial[1], m, Rel.EQ, "nl_serial_shape[1]", cls_name)
1853
- validator.check_int(d_lj_a[0], q, Rel.EQ, "d_LJ_A_shape[0]", cls_name)
1854
- validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
1855
- validator.check_int(beta[0], 1, Rel.EQ, "beta_shape[0]", cls_name)
1856
- return [n, 3], [n,], [n,]
1857
-
1858
- def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b, beta):
1859
- validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
1860
- validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
1861
- validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
1862
- validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
1863
- validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
1864
- validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
1865
- validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
1866
- validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
1867
- validator.check_tensor_dtype_valid('beta', beta, [mstype.float32], self.name)
1868
- return mstype.float32, mstype.float32, mstype.float32
1869
-
1870
-
1871
- class Dihedral14ForceWithAtomEnergyVirial(PrimitiveWithInfer):
1872
- """
1873
- Calculate the Lennard-Jones and Coulumb energy correction and force correction
1874
- for each necessary dihedral 1,4 terms together and add them to the total force
1875
- and potential energy for each atom.
1876
-
1877
- The calculation formula of force correction is the same as operator
1878
- :class:`Dihedral14LJForceWithDirectCF`, and the energy correction part is the same
1879
- as operator :class:`Dihedral14LJEnergy` and :class:`Dihedral14CFEnergy`.
1880
-
1881
- .. warning::
1882
- This is an experimental prototype that is subject to change and/or deletion.
1883
-
1884
- Args:
1885
- nb14_numbers (int32): the number of necessary dihedral 1,4 terms m.
1886
- atom_numbers (int32): the number of atoms n.
1887
-
1888
- Inputs:
1889
- - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
1890
- The data type is uint32 and the shape is :math:`(n, 3)`.
1891
- - **LJtype** (Tensor) - The Lennard-Jones type of each atom.
1892
- The data type is int32 and the shape is :math:`(n,)`.
1893
- - **charge** (Tensor) - The charge carried by each atom.
1894
- The data type is float32 and the shape is :math:`(n,)`.
1895
- - **boxlength** (Tensor) - The length of molecular simulation box in 3 dimensions.
1896
- The data type is float32 and the shape is :math:`(3,)`.
1897
- - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
1898
- The data type is int32 and the shape is :math:`(m,)`.
1899
- - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
1900
- The data type is int32 and the shape is :math:`(m,)`.
1901
- - **lj_scale_factor** (Tensor) - The scale factor for the
1902
- Lennard-Jones part of force correction of each dihedral 1,4 term.
1903
- - **cf_scale_factor** (Tensor) - The scale factor for the Coulomb force.
1904
- The data type is float32 and the shape is :math:`(m,)`.
1905
- - **LJ_type_A** (Tensor) - The A parameter in Lennard-Jones scheme of each atom pair type.
1906
- The number of atom pair is q. The data type is float32 and the shape is :math:`(q,)`.
1907
- - **LJ_type_B** (Tensor) - The B parameter in Lennard-Jones scheme of each atom pair type.
1908
- The number of atom pair is q. The data type is float32 and the shape is :math:`(q,)`.
1909
-
1910
- Outputs:
1911
- - **frc** (Tensor) - The force felt by each atom.
1912
- The data type is float32 and the shape is :math:`(n, 3)`.
1913
- - **atom_energy** (Tensor) - The accumulated potential energy for each atom.
1914
- The data type is float32 and the shape is :math:`(n, )`.
1915
- - **atom_virial** (Tensor) - The accumulated potential virial for each atom.
1916
- The data type is float32 and the shape is :math:`(n, )`.
1917
-
1918
- Supported Platforms:
1919
- ``GPU``
1920
- """
1921
-
1922
- @prim_attr_register
1923
- def __init__(self, nb14_numbers, atom_numbers):
1924
- """Initialize Dihedral14LJCFForceWithAtomEnergy"""
1925
- validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
1926
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
1927
- self.dihedral_14_numbers = nb14_numbers
1928
- self.atom_numbers = atom_numbers
1929
-
1930
- self.init_prim_io_names(
1931
- inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength', 'a_14', 'b_14', 'lj_scale_factor',
1932
- 'cf_scale_factor', 'LJ_type_A', 'LJ_type_B'],
1933
- outputs=['frc', 'atom_energy', 'atom_virial'])
1934
- self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
1935
- self.add_prim_attr('atom_numbers', self.atom_numbers)
1936
-
1937
- def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
1938
- lj_scale_factor_shape, cf_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
1939
- cls_name = self.name
1940
- n = self.atom_numbers
1941
- m = self.dihedral_14_numbers
1942
- q = lj_type_a_shape[0]
1943
- validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
1944
- validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
1945
- validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
1946
- validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
1947
- validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
1948
- validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
1949
- validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
1950
- validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
1951
- validator.check_int(len(lj_type_a_shape), 1, Rel.EQ, "LJ_type_A_dim", cls_name)
1952
- validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
1953
-
1954
- validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
1955
- validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
1956
- validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
1957
- validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
1958
- validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
1959
- validator.check_int(lj_type_a_shape[0], q, Rel.EQ, "LJ_type_A_shape", cls_name)
1960
- validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
1961
- validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
1962
- validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
1963
- validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
1964
- validator.check_int(cf_scale_factor_shape[0], m, Rel.EQ, "cf_scale_factor_shape", cls_name)
1965
- return [n, 3], [n,], [n,]
1966
-
1967
- def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
1968
- lj_scale_factor_type, cf_scale_factor_type, lj_type_a_type, lj_type_b_type):
1969
- validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
1970
- validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
1971
- validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
1972
- validator.check_tensor_dtype_valid('boxlength_f', boxlength_f_type, [mstype.float32], self.name)
1973
- validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
1974
- validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
1975
- validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
1976
- validator.check_tensor_dtype_valid('cf_scale_factor', cf_scale_factor_type, [mstype.float32], self.name)
1977
- validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
1978
- validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)
1979
-
1980
- return mstype.float32, mstype.float32, mstype.float32
1981
-
1982
-
1983
- class PMEEnergyUpdate(PrimitiveWithInfer):
1984
- """
1985
- Calculate the Coulumb energy of the system using PME method for pressure.
1986
-
1987
- Because there is a large amount of inputs and each of them are related,
1988
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
1989
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
1990
-
1991
- .. warning::
1992
- This is an experimental prototype that is subject to change and/or deletion.
1993
-
1994
- Args:
1995
- atom_numbers (int32): the number of atoms, n.
1996
- excluded_numbers (int32): the length of excluded list, E.
1997
- beta (float32): the PME beta parameter, determined by the
1998
- non-bond cutoff value and simulation precision tolerance.
1999
- fftx (int32): the number of points for Fourier transform in dimension X.
2000
- ffty (int32): the number of points for Fourier transform in dimension Y.
2001
- fftz (int32): the number of points for Fourier transform in dimension Z.
2002
- box_length_0 (float32): the value of boxlength idx 0.
2003
- box_length_1 (float32): the value of boxlength idx 1.
2004
- box_length_2 (float32): the value of boxlength idx 2.
2005
- max_neighbor_numbers (int32): the max neighbor numbers, m, default 800.
2006
- need_update (int32): if need_update = 1, calculate the pressure, default 0.
2007
-
2008
- Inputs:
2009
- - **uint_crd** (Tensor) - The unsigned int coordinates value of each atom.
2010
- The data type is uint32 and the shape is :math:`(n, 3)`
2011
- - **charge** (Tensor) - The charge carried by each atom.
2012
- The data type is float32 and the shape is :math:`(n,)`
2013
- - **nl_numbers** - (Tensor) - The each atom.
2014
- The data type is int32 and the shape is :math:`(n, 3)`
2015
- - **nl_serial** - (Tensor) - The neighbor list of each atom, the max number is 800.
2016
- The data type is int32 and the shape is :math:`(n, m)`
2017
- - **scaler** (Tensor) - The scale factor between real space
2018
- coordinates and its unsigned int value. The data type is float32 and the shape is :math:`(3,)`
2019
- - **excluded_list_start** (Tensor) - The start excluded index
2020
- in excluded list for each atom. The data type is int32 and the shape is :math:`(n,)`
2021
- - **excluded_list** (Tensor) - The contiguous join of excluded
2022
- list of each atom. E is the number of excluded atoms. The data type is int32 and the shape is :math:`(E,)`
2023
- - **excluded_atom_numbers** (Tensor) - The number of atom excluded
2024
- in excluded list for each atom. The data type is int32 and the shape is :math:`(n,)`
2025
- - **factor** (Tensor) - The factor parameter to be updated in pressure calculation.
2026
- The data type is float32 and the shape is :math:`(1,)`
2027
- - **beta** (Tensor) - The PME beta parameter to be updated in pressure calculation.
2028
- The data type is float32 and the shape is :math:`(1,)`
2029
-
2030
- Outputs:
2031
- - **reciprocal_ene** (Tensor) - The reciprocal term of PME energy.
2032
- The data type is float32 and the the shape is :math:`(1,)`.
2033
- - **self_ene** (Tensor) - The self term of PME energy.
2034
- The data type is float32 and the the shape is :math:`(1,)`.
2035
- - **direct_ene** (Tensor) - The direct term of PME energy.
2036
- The data type is float32 and the the shape is :math:`(1,)`.
2037
- - **correction_ene** (Tensor) - The correction term of PME energy.
2038
- The data type is float32 and the the shape is :math:`(1,)`.
2039
-
2040
- Supported Platforms:
2041
- ``GPU``
2042
- """
2043
-
2044
- @prim_attr_register
2045
- def __init__(self, atom_numbers, excluded_numbers, beta, fftx, ffty, fftz, box_length_0, box_length_1,
2046
- box_length_2, max_neighbor_numbers=800, need_update=0):
2047
- """Initialize PMEEnergyUpdate"""
2048
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
2049
- validator.check_value_type('excluded_numbers', excluded_numbers, int, self.name)
2050
- validator.check_value_type('beta', beta, float, self.name)
2051
- validator.check_value_type('fftx', fftx, int, self.name)
2052
- validator.check_value_type('ffty', ffty, int, self.name)
2053
- validator.check_value_type('fftz', fftz, int, self.name)
2054
- validator.check_value_type('box_length_0', box_length_0, float, self.name)
2055
- validator.check_value_type('box_length_1', box_length_1, float, self.name)
2056
- validator.check_value_type('box_length_2', box_length_2, float, self.name)
2057
- validator.check_value_type('max_neighbor_numbers', max_neighbor_numbers, int, self.name)
2058
- validator.check_value_type('need_update', need_update, int, self.name)
2059
- self.atom_numbers = atom_numbers
2060
- self.excluded_numbers = excluded_numbers
2061
- self.beta = beta
2062
- self.fftx = fftx
2063
- self.ffty = ffty
2064
- self.fftz = fftz
2065
- self.box_length_0 = box_length_0
2066
- self.box_length_1 = box_length_1
2067
- self.box_length_2 = box_length_2
2068
- self.max_neighbor_numbers = max_neighbor_numbers
2069
- self.need_update = need_update
2070
- self.init_prim_io_names(
2071
- inputs=['uint_crd', 'charge', 'nl_numbers', 'nl_serial', 'scaler', 'excluded_list_start',
2072
- 'excluded_list', 'excluded_atom_numbers', 'factor', 'beta'],
2073
- outputs=['reciprocal_ene', 'self_ene', 'direct_ene', 'correction_ene'])
2074
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2075
- self.add_prim_attr('excluded_numbers', self.excluded_numbers)
2076
- self.add_prim_attr('beta', self.beta)
2077
- self.add_prim_attr('fftx', self.fftx)
2078
- self.add_prim_attr('ffty', self.ffty)
2079
- self.add_prim_attr('fftz', self.fftz)
2080
- self.add_prim_attr('box_length_0', self.box_length_0)
2081
- self.add_prim_attr('box_length_1', self.box_length_1)
2082
- self.add_prim_attr('box_length_2', self.box_length_2)
2083
- self.add_prim_attr('max_neighbor_numbers', self.max_neighbor_numbers)
2084
- self.add_prim_attr('need_update', self.need_update)
2085
-
2086
- def infer_shape(self, uint_crd, charge, nl_numbers, nl_serial, scaler, excluded_list_start,
2087
- excluded_list, excluded_atom_numbers, factor, beta):
2088
- cls_name = self.name
2089
- n = self.atom_numbers
2090
- m = self.max_neighbor_numbers
2091
- e = self.excluded_numbers
2092
- validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
2093
- validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
2094
- validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
2095
- validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
2096
- validator.check_int(len(excluded_list_start), 1, Rel.EQ, "excluded_list_start_dim", cls_name)
2097
- validator.check_int(len(excluded_atom_numbers), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name)
2098
- validator.check_int(len(excluded_list), 1, Rel.EQ, "excluded_list_dim", cls_name)
2099
- validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
2100
- validator.check_int(len(factor), 1, Rel.EQ, "factor_dim", cls_name)
2101
- validator.check_int(len(beta), 1, Rel.EQ, "beta_dim", cls_name)
2102
- validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
2103
- validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
2104
- validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
2105
- validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape[0]", cls_name)
2106
- validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
2107
- validator.check_int(nl_serial[1], m, Rel.EQ, "nl_serial_shape[1]", cls_name)
2108
- validator.check_int(excluded_list_start[0], n, Rel.EQ, "excluded_list_start_shape", cls_name)
2109
- validator.check_int(excluded_atom_numbers[0], n, Rel.EQ, "excluded_atom_numbers_shape", cls_name)
2110
- validator.check_int(excluded_list[0], e, Rel.EQ, "excluded_list_shape", cls_name)
2111
- validator.check_int(factor[0], 1, Rel.EQ, "factor_shape", cls_name)
2112
- validator.check_int(beta[0], 1, Rel.EQ, "beta_shape", cls_name)
2113
- validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
2114
- return [1,], [1,], [1,], [1,]
2115
-
2116
- def infer_dtype(self, uint_crd, charge, nl_numbers, nl_serial, scaler, excluded_list_start,
2117
- excluded_list, excluded_atom_numbers, factor, beta):
2118
- validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
2119
- validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
2120
- validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
2121
- validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
2122
- validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
2123
- validator.check_tensor_dtype_valid('excluded_list_start', excluded_list_start, [mstype.int32],
2124
- self.name)
2125
- validator.check_tensor_dtype_valid('excluded_list', excluded_list, [mstype.int32],
2126
- self.name)
2127
- validator.check_tensor_dtype_valid('excluded_atom_numbers', excluded_atom_numbers, [mstype.int32],
2128
- self.name)
2129
- validator.check_tensor_dtype_valid('factor', factor, [mstype.float32], self.name)
2130
- validator.check_tensor_dtype_valid('beta', beta, [mstype.float32], self.name)
2131
- return charge, charge, charge, charge
2132
-
2133
-
2134
- class ConstrainForceCycle(PrimitiveWithInfer):
2135
- """
2136
- Calculate the constraint force in each iteration.
2137
-
2138
- .. warning::
2139
- This is an experimental prototype that is subject to change and/or deletion.
2140
-
2141
- Args:
2142
- atom_numbers (int32): the number of atoms n.
2143
- constrain_pair_numbers (int32): the number of constrain pairs m.
2144
-
2145
- Inputs:
2146
- - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
2147
- The data type is uint32 and the shape is :math:`(n, 3)`.
2148
- - **scaler** (Tensor) - The 3-D scale factor (x, y, z),
2149
- The data type is float32 and the shape is :math:`(3,)`.
2150
- - **pair_dr** (Tensor) - The displacement vector of each constrained atom pair.
2151
- The data type is float32 and the shape is :math:`(m, 3)`.
2152
- - **atom_i_serials** (Tensor) - The first atom index of each constrained atom pair.
2153
- The data type is int32 and the shape is :math:`(m,)`.
2154
- - **atom_j_serials** (Tensor) - The second atom index of each constrained atom pair.
2155
- The data type is int32 and the shape is :math:`(m,)`.
2156
- - **constant_rs** (Tensor) - The constrained distance of each constrained atom pair.
2157
- The data type is float32 and the shape is :math:`(m,)`.
2158
- - **constrain_ks** (Tensor) - The coefficient of each constrained atom pair.
2159
- The data type is float32 and the shape is :math:`(m,)`.
2160
-
2161
- Outputs:
2162
- - **test_frc** (Tensor) - The constraint force.
2163
- The data type is float32 and the shape is :math:`(n, 3)`.
2164
-
2165
- Supported Platforms:
2166
- ``GPU``
2167
- """
2168
-
2169
- @prim_attr_register
2170
- def __init__(self, atom_numbers, constrain_pair_numbers):
2171
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
2172
- validator.check_value_type('constrain_pair_numbers', constrain_pair_numbers, int, self.name)
2173
- self.atom_numbers = atom_numbers
2174
- self.constrain_pair_numbers = constrain_pair_numbers
2175
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2176
- self.add_prim_attr('constrain_pair_numbers', self.constrain_pair_numbers)
2177
- self.init_prim_io_names(
2178
- inputs=['uint_crd', 'scaler', 'pair_dr', 'atom_i_serials', 'atom_j_serials',
2179
- 'constant_rs', 'constrain_ks'],
2180
- outputs=['test_frc'])
2181
-
2182
- def infer_shape(self, uint_crd_shape, scaler_shape, pair_dr_shape, atom_i_serials_shape,
2183
- atom_j_serials_shape, constant_rs_shape, constrain_ks_shape):
2184
- cls_name = self.name
2185
- n = self.atom_numbers
2186
- m = self.constrain_pair_numbers
2187
- validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name)
2188
- validator.check_int(len(scaler_shape), 1, Rel.EQ, "scaler_dim", cls_name)
2189
- validator.check_int(len(pair_dr_shape), 2, Rel.EQ, "pair_dr_dim", cls_name)
2190
- validator.check_int(len(atom_i_serials_shape), 1, Rel.EQ, "atom_i_serials_dim", cls_name)
2191
- validator.check_int(len(atom_j_serials_shape), 1, Rel.EQ, "atom_j_serials_dim", cls_name)
2192
- validator.check_int(len(constant_rs_shape), 1, Rel.EQ, "constant_rs_dim", cls_name)
2193
- validator.check_int(len(constrain_ks_shape), 1, Rel.EQ, "constrain_ks_dim", cls_name)
2194
-
2195
- validator.check_int(uint_crd_shape[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
2196
- validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
2197
- validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
2198
- validator.check_int(pair_dr_shape[0], m, Rel.EQ, "pair_dr_shape[0]", cls_name)
2199
- validator.check_int(pair_dr_shape[1], 3, Rel.EQ, "pair_dr_shape[1]", cls_name)
2200
- validator.check_int(atom_i_serials_shape[0], m, Rel.EQ, "atom_i_serials_shape[0]", cls_name)
2201
- validator.check_int(atom_j_serials_shape[0], m, Rel.EQ, "atom_j_serials_shape[0]", cls_name)
2202
- validator.check_int(constant_rs_shape[0], m, Rel.EQ, "constant_rs_shape[0]", cls_name)
2203
- validator.check_int(constrain_ks_shape[0], m, Rel.EQ, "constrain_ks_shape[0]", cls_name)
2204
- return [n, 3]
2205
-
2206
- def infer_dtype(self, uint_crd_dtype, scaler_dtype, pair_dr_dtype, atom_i_serials_dtype,
2207
- atom_j_serials_dtype, constant_rs_dtype, constrain_ks_dtype):
2208
- validator.check_tensor_dtype_valid('uint_crd', uint_crd_dtype, [mstype.uint32], self.name)
2209
- validator.check_tensor_dtype_valid('scaler', scaler_dtype, [mstype.float32], self.name)
2210
- validator.check_tensor_dtype_valid('pair_dr', pair_dr_dtype, [mstype.float32], self.name)
2211
- validator.check_tensor_dtype_valid('atom_i_serials', atom_i_serials_dtype, [mstype.int32], self.name)
2212
- validator.check_tensor_dtype_valid('atom_j_serials', atom_j_serials_dtype, [mstype.int32], self.name)
2213
- validator.check_tensor_dtype_valid('constant_rs', constant_rs_dtype, [mstype.float32], self.name)
2214
- validator.check_tensor_dtype_valid('constrain_ks', constrain_ks_dtype, [mstype.float32], self.name)
2215
- return mstype.float32
2216
-
2217
-
2218
- class ConstrainForceVirial(PrimitiveWithInfer):
2219
- """
2220
- Calculate the constraint force and virial in a step with iteration numbers.
2221
-
2222
- .. warning::
2223
- This is an experimental prototype that is subject to change and/or deletion.
2224
-
2225
- Args:
2226
- atom_numbers (int32): the number of atoms n.
2227
- constrain_pair_numbers (int32): the number of constrain pairs m.
2228
- iteration_numbers (int32): the number of iteration numbers p.
2229
- half_exp_gamma_plus_half (float32): half exp_gamma plus half q.
2230
-
2231
- Inputs:
2232
- - **crd** (Tensor) - The coordinate of each atom.
2233
- The data type is float32 and the shape is :math:`(n, 3)`.
2234
- - **quarter_cof** (Tensor) - The 3-D scale factor.
2235
- The data type is float32 and the shape is :math:`(3,)`.
2236
- - **mass_inverse** (Tensor) - The inverse value of mass of each atom.
2237
- The data type is float32 and the shape is :math:`(n,)`.
2238
- - **scaler** (Tensor) - The 3-D scale factor (x, y, z),
2239
- The data type is float32 and the shape is :math:`(3,)`.
2240
- - **pair_dr** (Tensor) - The displacement vector of each constrained atom pair.
2241
- The data type is float32 and the shape is :math:`(m, 3)`.
2242
- - **atom_i_serials** (Tensor) - The first atom index of each constrained atom pair.
2243
- The data type is int32 and the shape is :math:`(m,)`.
2244
- - **atom_j_serials** (Tensor) - The second atom index of each constrained atom pair.
2245
- The data type is int32 and the shape is :math:`(m,)`.
2246
- - **constant_rs** (Tensor) - The constrained distance of each constrained atom pair.
2247
- The data type is float32 and the shape is :math:`(m,)`.
2248
- - **constrain_ks** (Tensor) - The coefficient of each constrained atom pair.
2249
- The data type is float32 and the shape is :math:`(m,)`.
2250
-
2251
- Outputs:
2252
- - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
2253
- The data type is uint32 and the shape is :math:`(n, 3)`.
2254
- - **frc** (Tensor) - The force felt by each atom.
2255
- The data type is float32 and the shape is :math:`(n, 3)`.
2256
- - **virial** (Tensor) - The constraint virial on each atom.
2257
- The data type is float32 and the shape is :math:`(m,)`.
2258
-
2259
- Supported Platforms:
2260
- ``GPU``
2261
- """
2262
-
2263
- @prim_attr_register
2264
- def __init__(self, atom_numbers, constrain_pair_numbers, iteration_numbers, half_exp_gamma_plus_half):
2265
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
2266
- validator.check_value_type('constrain_pair_numbers', constrain_pair_numbers, int, self.name)
2267
- validator.check_value_type('iteration_numbers', iteration_numbers, int, self.name)
2268
- validator.check_value_type('half_exp_gamma_plus_half', half_exp_gamma_plus_half, float, self.name)
2269
- self.atom_numbers = atom_numbers
2270
- self.constrain_pair_numbers = constrain_pair_numbers
2271
- self.iteration_numbers = iteration_numbers
2272
- self.half_exp_gamma_plus_half = half_exp_gamma_plus_half
2273
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2274
- self.add_prim_attr('constrain_pair_numbers', self.constrain_pair_numbers)
2275
- self.add_prim_attr('iteration_numbers', self.iteration_numbers)
2276
- self.add_prim_attr('half_exp_gamma_plus_half', self.half_exp_gamma_plus_half)
2277
-
2278
- self.init_prim_io_names(
2279
- inputs=['crd', 'quarter_cof', 'mass_inverse',
2280
- 'scaler', 'pair_dr', 'atom_i_serials', 'atom_j_serials',
2281
- 'constant_rs', 'constrain_ks'],
2282
- outputs=['uint_crd', 'frc', 'virial'])
2283
-
2284
- def infer_shape(self, crd, quarter_cof, mass_inverse, scaler_shape, pair_dr_shape, atom_i_serials_shape,
2285
- atom_j_serials_shape, constant_rs_shape, constrain_ks_shape):
2286
- cls_name = self.name
2287
- n = self.atom_numbers
2288
- m = self.constrain_pair_numbers
2289
- validator.check_int(len(crd), 2, Rel.EQ, "crd_dim", cls_name)
2290
- validator.check_int(len(quarter_cof), 1, Rel.EQ, "quarter_cof_dim", cls_name)
2291
- validator.check_int(len(mass_inverse), 1, Rel.EQ, "mass_inverse_dim", cls_name)
2292
- validator.check_int(len(scaler_shape), 1, Rel.EQ, "scaler_dim", cls_name)
2293
- validator.check_int(len(pair_dr_shape), 2, Rel.EQ, "pair_dr_dim", cls_name)
2294
- validator.check_int(len(atom_i_serials_shape), 1, Rel.EQ, "atom_i_serials_dim", cls_name)
2295
- validator.check_int(len(atom_j_serials_shape), 1, Rel.EQ, "atom_j_serials_dim", cls_name)
2296
- validator.check_int(len(constant_rs_shape), 1, Rel.EQ, "constant_rs_dim", cls_name)
2297
- validator.check_int(len(constrain_ks_shape), 1, Rel.EQ, "constrain_ks_dim", cls_name)
2298
- validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", cls_name)
2299
- validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
2300
- validator.check_int(quarter_cof[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
2301
- validator.check_int(mass_inverse[0], n, Rel.EQ, "mass_inverse_shape", cls_name)
2302
- validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
2303
- validator.check_int(pair_dr_shape[0], m, Rel.EQ, "pair_dr_shape[0]", cls_name)
2304
- validator.check_int(pair_dr_shape[1], 3, Rel.EQ, "pair_dr_shape[1]", cls_name)
2305
- validator.check_int(atom_i_serials_shape[0], m, Rel.EQ, "atom_i_serials_shape[0]", cls_name)
2306
- validator.check_int(atom_j_serials_shape[0], m, Rel.EQ, "atom_j_serials_shape[0]", cls_name)
2307
- validator.check_int(constant_rs_shape[0], m, Rel.EQ, "constant_rs_shape[0]", cls_name)
2308
- validator.check_int(constrain_ks_shape[0], m, Rel.EQ, "constrain_ks_shape[0]", cls_name)
2309
- return [n, 3], [n, 3], [m,]
2310
-
2311
- def infer_dtype(self, crd, quarter_cof, mass_inverse, scaler_dtype, pair_dr_dtype, atom_i_serials_dtype,
2312
- atom_j_serials_dtype, constant_rs_dtype, constrain_ks_dtype):
2313
- validator.check_tensor_dtype_valid('crd', crd, [mstype.float32], self.name)
2314
- validator.check_tensor_dtype_valid('quarter_cof', quarter_cof, [mstype.float32], self.name)
2315
- validator.check_tensor_dtype_valid('mass_inverse', mass_inverse, [mstype.float32], self.name)
2316
- validator.check_tensor_dtype_valid('scaler', scaler_dtype, [mstype.float32], self.name)
2317
- validator.check_tensor_dtype_valid('pair_dr', pair_dr_dtype, [mstype.float32], self.name)
2318
- validator.check_tensor_dtype_valid('atom_i_serials', atom_i_serials_dtype, [mstype.int32], self.name)
2319
- validator.check_tensor_dtype_valid('atom_j_serials', atom_j_serials_dtype, [mstype.int32], self.name)
2320
- validator.check_tensor_dtype_valid('constant_rs', constant_rs_dtype, [mstype.float32], self.name)
2321
- validator.check_tensor_dtype_valid('constrain_ks', constrain_ks_dtype, [mstype.float32], self.name)
2322
- return mstype.uint32, mstype.float32, mstype.float32
2323
-
2324
-
2325
- class ConstrainForce(PrimitiveWithInfer):
2326
- """
2327
- Calculate the constraint force in a step with iteration numbers.
2328
-
2329
- .. warning::
2330
- This is an experimental prototype that is subject to change and/or deletion.
2331
-
2332
- Args:
2333
- atom_numbers (int32): the number of atoms n.
2334
- constrain_pair_numbers (int32): the number of constrain pairs m.
2335
- iteration_numbers (int32): the number of iteration numbers p.
2336
- half_exp_gamma_plus_half (float32): half exp_gamma plus half q.
2337
-
2338
- Inputs:
2339
- - **crd** (Tensor) - The coordinate of each atom.
2340
- The data type is float32 and the shape is :math:`(n, 3)`.
2341
- - **quarter_cof** (Tensor) - The 3-D scale factor.
2342
- The data type is float32 and the shape is :math:`(3,)`.
2343
- - **mass_inverse** (Tensor) - The inverse value of mass of each atom.
2344
- The data type is float32 and the shape is :math:`(n,)`.
2345
- - **scaler** (Tensor) - The 3-D scale factor (x, y, z),
2346
- The data type is float32 and the shape is :math:`(3,)`.
2347
- - **pair_dr** (Tensor) - The displacement vector of each constrained atom pair.
2348
- The data type is float32 and the shape is :math:`(m, 3)`.
2349
- - **atom_i_serials** (Tensor) - The first atom index of each constrained atom pair.
2350
- The data type is int32 and the shape is :math:`(m,)`.
2351
- - **atom_j_serials** (Tensor) - The second atom index of each constrained atom pair.
2352
- The data type is int32 and the shape is :math:`(m,)`.
2353
- - **constant_rs** (Tensor) - The constrained distance of each constrained atom pair.
2354
- The data type is float32 and the shape is :math:`(m,)`.
2355
- - **constrain_ks** (Tensor) - The coefficient of each constrained atom pair.
2356
- The data type is float32 and the shape is :math:`(m,)`.
2357
-
2358
- Outputs:
2359
- - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
2360
- The data type is uint32 and the shape is :math:`(n, 3)`.
2361
- - **frc** (Tensor) - The constraint force on each atom.
2362
- The data type is float32 and the shape is :math:`(n, 3)`.
2363
- - **virial** (Tensor) - The constraint virial on each atom and it is zero.
2364
- The data type is float32 and the shape is :math:`(m,)`.
2365
-
2366
- Supported Platforms:
2367
- ``GPU``
2368
- """
2369
-
2370
- @prim_attr_register
2371
- def __init__(self, atom_numbers, constrain_pair_numbers, iteration_numbers, half_exp_gamma_plus_half):
2372
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
2373
- validator.check_value_type('constrain_pair_numbers', constrain_pair_numbers, int, self.name)
2374
- validator.check_value_type('iteration_numbers', iteration_numbers, int, self.name)
2375
- validator.check_value_type('half_exp_gamma_plus_half', half_exp_gamma_plus_half, float, self.name)
2376
- self.atom_numbers = atom_numbers
2377
- self.constrain_pair_numbers = constrain_pair_numbers
2378
- self.iteration_numbers = iteration_numbers
2379
- self.half_exp_gamma_plus_half = half_exp_gamma_plus_half
2380
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2381
- self.add_prim_attr('constrain_pair_numbers', self.constrain_pair_numbers)
2382
- self.add_prim_attr('iteration_numbers', self.iteration_numbers)
2383
- self.add_prim_attr('half_exp_gamma_plus_half', self.half_exp_gamma_plus_half)
2384
-
2385
- self.init_prim_io_names(
2386
- inputs=['crd', 'quarter_cof', 'mass_inverse',
2387
- 'scaler', 'pair_dr', 'atom_i_serials', 'atom_j_serials', 'constant_rs', 'constrain_ks'],
2388
- outputs=['uint_crd', 'frc', 'virial'])
2389
-
2390
- def infer_shape(self, crd, quarter_cof, mass_inverse, scaler_shape, pair_dr_shape, atom_i_serials_shape,
2391
- atom_j_serials_shape, constant_rs_shape, constrain_ks_shape):
2392
- cls_name = self.name
2393
- n = self.atom_numbers
2394
- m = self.constrain_pair_numbers
2395
- validator.check_int(len(crd), 2, Rel.EQ, "crd_dim", cls_name)
2396
- validator.check_int(len(quarter_cof), 1, Rel.EQ, "quarter_cof_dim", cls_name)
2397
- validator.check_int(len(mass_inverse), 1, Rel.EQ, "mass_inverse_dim", cls_name)
2398
- validator.check_int(len(scaler_shape), 1, Rel.EQ, "scaler_dim", cls_name)
2399
- validator.check_int(len(pair_dr_shape), 2, Rel.EQ, "pair_dr_dim", cls_name)
2400
- validator.check_int(len(atom_i_serials_shape), 1, Rel.EQ, "atom_i_serials_dim", cls_name)
2401
- validator.check_int(len(atom_j_serials_shape), 1, Rel.EQ, "atom_j_serials_dim", cls_name)
2402
- validator.check_int(len(constant_rs_shape), 1, Rel.EQ, "constant_rs_dim", cls_name)
2403
- validator.check_int(len(constrain_ks_shape), 1, Rel.EQ, "constrain_ks_dim", cls_name)
2404
- validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", cls_name)
2405
- validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
2406
- validator.check_int(quarter_cof[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
2407
- validator.check_int(mass_inverse[0], n, Rel.EQ, "mass_inverse_shape", cls_name)
2408
- validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
2409
- validator.check_int(pair_dr_shape[0], m, Rel.EQ, "pair_dr_shape[0]", cls_name)
2410
- validator.check_int(pair_dr_shape[1], 3, Rel.EQ, "pair_dr_shape[1]", cls_name)
2411
- validator.check_int(atom_i_serials_shape[0], m, Rel.EQ, "atom_i_serials_shape[0]", cls_name)
2412
- validator.check_int(atom_j_serials_shape[0], m, Rel.EQ, "atom_j_serials_shape[0]", cls_name)
2413
- validator.check_int(constant_rs_shape[0], m, Rel.EQ, "constant_rs_shape[0]", cls_name)
2414
- validator.check_int(constrain_ks_shape[0], m, Rel.EQ, "constrain_ks_shape[0]", cls_name)
2415
- return [n, 3], [n, 3], [m,]
2416
-
2417
- def infer_dtype(self, crd, quarter_cof, mass_inverse, scaler_dtype, pair_dr_dtype, atom_i_serials_dtype,
2418
- atom_j_serials_dtype, constant_rs_dtype, constrain_ks_dtype):
2419
- validator.check_tensor_dtype_valid('crd', crd, [mstype.float32], self.name)
2420
- validator.check_tensor_dtype_valid('quarter_cof', quarter_cof, [mstype.float32], self.name)
2421
- validator.check_tensor_dtype_valid('mass_inverse', mass_inverse, [mstype.float32], self.name)
2422
- validator.check_tensor_dtype_valid('scaler', scaler_dtype, [mstype.float32], self.name)
2423
- validator.check_tensor_dtype_valid('pair_dr', pair_dr_dtype, [mstype.float32], self.name)
2424
- validator.check_tensor_dtype_valid('atom_i_serials', atom_i_serials_dtype, [mstype.int32], self.name)
2425
- validator.check_tensor_dtype_valid('atom_j_serials', atom_j_serials_dtype, [mstype.int32], self.name)
2426
- validator.check_tensor_dtype_valid('constant_rs', constant_rs_dtype, [mstype.float32], self.name)
2427
- validator.check_tensor_dtype_valid('constrain_ks', constrain_ks_dtype, [mstype.float32], self.name)
2428
- return mstype.uint32, mstype.float32, mstype.float32
2429
-
2430
-
2431
- class Constrain(PrimitiveWithInfer):
2432
- """
2433
- Calculate the constraint force and virial depends on pressure calculation.
2434
-
2435
- .. warning::
2436
- This is an experimental prototype that is subject to change and/or deletion.
2437
-
2438
- Args:
2439
- atom_numbers (int32): the number of atoms n.
2440
- constrain_pair_numbers (int32): the number of constrain pairs m.
2441
- iteration_numbers (int32): the number of iteration numbers p.
2442
- half_exp_gamma_plus_half (float32): half exp_gamma plus half q.
2443
- update_interval (int32): the number of update interval, default 10.
2444
-
2445
- Inputs:
2446
- - **crd** (Tensor) - The coordinate of each atom.
2447
- The data type is float32 and the shape is :math:`(n, 3)`.
2448
- - **quarter_cof** (Tensor) - The 3-D scale factor.
2449
- The data type is float32 and the shape is :math:`(3,)`.
2450
- - **mass_inverse** (Tensor) - The inverse value of mass of each atom.
2451
- The data type is float32 and the shape is :math:`(n,)`.
2452
- - **scaler** (Tensor) - The 3-D scale factor (x, y, z),
2453
- The data type is float32 and the shape is :math:`(3,)`.
2454
- - **pair_dr** (Tensor) - The displacement vector of each constrained atom pair.
2455
- The data type is float32 and the shape is :math:`(m, 3)`.
2456
- - **atom_i_serials** (Tensor) - The first atom index of each constrained atom pair.
2457
- The data type is int32 and the shape is :math:`(m,)`.
2458
- - **atom_j_serials** (Tensor) - The second atom index of each constrained atom pair.
2459
- The data type is int32 and the shape is :math:`(m,)`.
2460
- - **constant_rs** (Tensor) - The constrained distance of each constrained atom pair.
2461
- The data type is float32 and the shape is :math:`(m,)`.
2462
- - **constrain_ks** (Tensor) - The coefficient of each constrained atom pair.
2463
- The data type is float32 and the shape is :math:`(m,)`.
2464
- - **need_pressure** (Tensor) - If need pressure, 1 else 0.
2465
- The data type is int32 and the shape is :math:`(1,)` or :math:`()`.
2466
-
2467
- Outputs:
2468
- - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
2469
- The data type is uint32 and the shape is :math:`(n, 3)`.
2470
- - **frc** (Tensor) - The constraint force on each atom.
2471
- The data type is float32 and the shape is :math:`(n, 3)`.
2472
- - **virial** (Tensor) - The constraint virial on each atom.
2473
- The data type is float32 and the shape is :math:`(m,)`.
2474
-
2475
- Supported Platforms:
2476
- ``GPU``
2477
- """
2478
-
2479
- @prim_attr_register
2480
- def __init__(self, atom_numbers, constrain_pair_numbers, iteration_numbers, half_exp_gamma_plus_half,
2481
- update_interval=10):
2482
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
2483
- validator.check_value_type('constrain_pair_numbers', constrain_pair_numbers, int, self.name)
2484
- validator.check_value_type('iteration_numbers', iteration_numbers, int, self.name)
2485
- validator.check_value_type('half_exp_gamma_plus_half', half_exp_gamma_plus_half, float, self.name)
2486
- validator.check_value_type('update_interval', update_interval, int, self.name)
2487
- self.atom_numbers = atom_numbers
2488
- self.constrain_pair_numbers = constrain_pair_numbers
2489
- self.iteration_numbers = iteration_numbers
2490
- self.half_exp_gamma_plus_half = half_exp_gamma_plus_half
2491
- self.update_interval = update_interval
2492
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2493
- self.add_prim_attr('constrain_pair_numbers', self.constrain_pair_numbers)
2494
- self.add_prim_attr('iteration_numbers', self.iteration_numbers)
2495
- self.add_prim_attr('half_exp_gamma_plus_half', self.half_exp_gamma_plus_half)
2496
- self.add_prim_attr('update_interval', self.update_interval)
2497
-
2498
- self.init_prim_io_names(
2499
- inputs=['crd', 'quarter_cof', 'mass_inverse',
2500
- 'scaler', 'pair_dr', 'atom_i_serials', 'atom_j_serials',
2501
- 'constant_rs', 'constrain_ks', 'need_pressure'],
2502
- outputs=['uint_crd', 'frc', 'virial'])
2503
-
2504
- def infer_shape(self, crd, quarter_cof, mass_inverse, scaler_shape, pair_dr_shape, atom_i_serials_shape,
2505
- atom_j_serials_shape, constant_rs_shape, constrain_ks_shape, need_pressure):
2506
- cls_name = self.name
2507
- n = self.atom_numbers
2508
- m = self.constrain_pair_numbers
2509
- validator.check_int(len(crd), 2, Rel.EQ, "crd_dim", cls_name)
2510
- validator.check_int(len(quarter_cof), 1, Rel.EQ, "quarter_cof_dim", cls_name)
2511
- validator.check_int(len(mass_inverse), 1, Rel.EQ, "mass_inverse_dim", cls_name)
2512
- validator.check_int(len(scaler_shape), 1, Rel.EQ, "scaler_dim", cls_name)
2513
- validator.check_int(len(pair_dr_shape), 2, Rel.EQ, "pair_dr_dim", cls_name)
2514
- validator.check_int(len(atom_i_serials_shape), 1, Rel.EQ, "atom_i_serials_dim", cls_name)
2515
- validator.check_int(len(atom_j_serials_shape), 1, Rel.EQ, "atom_j_serials_dim", cls_name)
2516
- validator.check_int(len(constant_rs_shape), 1, Rel.EQ, "constant_rs_dim", cls_name)
2517
- validator.check_int(len(constrain_ks_shape), 1, Rel.EQ, "constrain_ks_dim", cls_name)
2518
- validator.check_int(len(need_pressure), 1, Rel.LE, "need_pressure_dim", cls_name)
2519
- validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", cls_name)
2520
- validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
2521
- validator.check_int(quarter_cof[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
2522
- validator.check_int(mass_inverse[0], n, Rel.EQ, "mass_inverse_shape", cls_name)
2523
- validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
2524
- validator.check_int(pair_dr_shape[0], m, Rel.EQ, "pair_dr_shape[0]", cls_name)
2525
- validator.check_int(pair_dr_shape[1], 3, Rel.EQ, "pair_dr_shape[1]", cls_name)
2526
- validator.check_int(atom_i_serials_shape[0], m, Rel.EQ, "atom_i_serials_shape[0]", cls_name)
2527
- validator.check_int(atom_j_serials_shape[0], m, Rel.EQ, "atom_j_serials_shape[0]", cls_name)
2528
- validator.check_int(constant_rs_shape[0], m, Rel.EQ, "constant_rs_shape[0]", cls_name)
2529
- validator.check_int(constrain_ks_shape[0], m, Rel.EQ, "constrain_ks_shape[0]", cls_name)
2530
- if need_pressure:
2531
- validator.check_int(need_pressure[0], 1, Rel.EQ, "need_pressure_shape", self.name)
2532
- return [n, 3], [n, 3], [m,]
2533
-
2534
- def infer_dtype(self, crd, quarter_cof, mass_inverse, scaler_dtype, pair_dr_dtype, atom_i_serials_dtype,
2535
- atom_j_serials_dtype, constant_rs_dtype, constrain_ks_dtype, need_pressure):
2536
- validator.check_tensor_dtype_valid('crd', crd, [mstype.float32], self.name)
2537
- validator.check_tensor_dtype_valid('quarter_cof', quarter_cof, [mstype.float32], self.name)
2538
- validator.check_tensor_dtype_valid('mass_inverse', mass_inverse, [mstype.float32], self.name)
2539
- validator.check_tensor_dtype_valid('scaler', scaler_dtype, [mstype.float32], self.name)
2540
- validator.check_tensor_dtype_valid('pair_dr', pair_dr_dtype, [mstype.float32], self.name)
2541
- validator.check_tensor_dtype_valid('atom_i_serials', atom_i_serials_dtype, [mstype.int32], self.name)
2542
- validator.check_tensor_dtype_valid('atom_j_serials', atom_j_serials_dtype, [mstype.int32], self.name)
2543
- validator.check_tensor_dtype_valid('constant_rs', constant_rs_dtype, [mstype.float32], self.name)
2544
- validator.check_tensor_dtype_valid('constrain_ks', constrain_ks_dtype, [mstype.float32], self.name)
2545
- validator.check_tensor_dtype_valid('need_pressure', need_pressure, [mstype.int32], self.name)
2546
- return mstype.uint32, mstype.float32, mstype.float32