mindspore 1.10.0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl

This diff shows the differences between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (966)
  1. mindspore/.commit_id +1 -1
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +9 -4
  18. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +102 -0
  22. mindspore/_checkparam.py +1066 -1001
  23. mindspore/_extends/builtin_operations.py +32 -4
  24. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  25. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  26. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  27. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  28. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  29. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  30. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  31. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  32. mindspore/_extends/parse/__init__.py +5 -3
  33. mindspore/_extends/parse/namespace.py +17 -2
  34. mindspore/_extends/parse/parser.py +193 -34
  35. mindspore/_extends/parse/resources.py +7 -8
  36. mindspore/_extends/parse/standard_method.py +1780 -435
  37. mindspore/_extends/parse/trope.py +3 -1
  38. mindspore/amp.py +53 -58
  39. mindspore/atlprov.dll +0 -0
  40. mindspore/boost/adasum.py +3 -2
  41. mindspore/boost/boost.py +2 -2
  42. mindspore/boost/boost_cell_wrapper.py +46 -26
  43. mindspore/boost/dim_reduce.py +6 -5
  44. mindspore/boost/grad_accumulation.py +2 -1
  45. mindspore/boost/group_loss_scale_manager.py +1 -1
  46. mindspore/c1.dll +0 -0
  47. mindspore/c1xx.dll +0 -0
  48. mindspore/c2.dll +0 -0
  49. mindspore/cfgpersist.dll +0 -0
  50. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  51. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  52. mindspore/common/__init__.py +11 -10
  53. mindspore/common/_decorator.py +2 -0
  54. mindspore/common/_register_for_adapter.py +55 -0
  55. mindspore/common/_stub_tensor.py +201 -0
  56. mindspore/common/_utils.py +57 -0
  57. mindspore/common/api.py +582 -297
  58. mindspore/common/dtype.py +66 -18
  59. mindspore/common/dump.py +2 -2
  60. mindspore/common/initializer.py +38 -1
  61. mindspore/common/jit_config.py +25 -13
  62. mindspore/common/mutable.py +53 -24
  63. mindspore/common/parameter.py +60 -37
  64. mindspore/common/seed.py +8 -24
  65. mindspore/common/sparse_tensor.py +927 -0
  66. mindspore/common/tensor.py +1627 -3900
  67. mindspore/communication/__init__.py +10 -5
  68. mindspore/communication/_comm_helper.py +78 -214
  69. mindspore/communication/_hccl_management.py +2 -1
  70. mindspore/communication/management.py +136 -47
  71. mindspore/config/op_info.config +501 -1008
  72. mindspore/context.py +291 -56
  73. mindspore/d3dcompiler_47.dll +0 -0
  74. mindspore/dataset/__init__.py +12 -8
  75. mindspore/dataset/audio/__init__.py +9 -9
  76. mindspore/dataset/audio/transforms.py +1090 -228
  77. mindspore/dataset/audio/utils.py +87 -39
  78. mindspore/dataset/audio/validators.py +223 -1
  79. mindspore/dataset/callback/ds_callback.py +17 -15
  80. mindspore/dataset/core/config.py +246 -17
  81. mindspore/dataset/core/py_util_helpers.py +4 -3
  82. mindspore/dataset/core/validator_helpers.py +10 -10
  83. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  84. mindspore/dataset/debug/debug_hook.py +65 -0
  85. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  86. mindspore/dataset/engine/__init__.py +7 -3
  87. mindspore/dataset/engine/cache_client.py +9 -9
  88. mindspore/dataset/engine/datasets.py +648 -477
  89. mindspore/dataset/engine/datasets_audio.py +165 -167
  90. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  91. mindspore/dataset/engine/datasets_text.py +492 -342
  92. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  93. mindspore/dataset/engine/datasets_vision.py +1224 -699
  94. mindspore/dataset/engine/graphdata.py +134 -69
  95. mindspore/dataset/engine/iterators.py +50 -9
  96. mindspore/dataset/engine/offload.py +52 -31
  97. mindspore/dataset/engine/samplers.py +27 -24
  98. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  99. mindspore/dataset/engine/validators.py +213 -52
  100. mindspore/dataset/text/__init__.py +10 -8
  101. mindspore/dataset/text/transforms.py +152 -57
  102. mindspore/dataset/text/utils.py +98 -49
  103. mindspore/dataset/text/validators.py +25 -0
  104. mindspore/dataset/transforms/__init__.py +4 -2
  105. mindspore/dataset/transforms/c_transforms.py +11 -13
  106. mindspore/dataset/transforms/py_transforms.py +2 -2
  107. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  108. mindspore/dataset/transforms/transforms.py +13 -15
  109. mindspore/dataset/transforms/validators.py +7 -7
  110. mindspore/dataset/utils/__init__.py +2 -1
  111. mindspore/dataset/utils/browse_dataset.py +13 -13
  112. mindspore/dataset/utils/line_reader.py +121 -0
  113. mindspore/dataset/vision/__init__.py +8 -7
  114. mindspore/dataset/vision/c_transforms.py +125 -126
  115. mindspore/dataset/vision/py_transforms.py +37 -37
  116. mindspore/dataset/vision/py_transforms_util.py +23 -20
  117. mindspore/dataset/vision/transforms.py +316 -315
  118. mindspore/dataset/vision/utils.py +313 -17
  119. mindspore/dataset/vision/validators.py +6 -6
  120. mindspore/default_config.py +0 -1
  121. mindspore/dpcmi.dll +0 -0
  122. mindspore/{compression → experimental}/__init__.py +6 -5
  123. mindspore/experimental/map_parameter.py +275 -0
  124. mindspore/include/OWNERS +0 -1
  125. mindspore/include/api/callback/callback.h +9 -13
  126. mindspore/include/api/callback/ckpt_saver.h +2 -2
  127. mindspore/include/api/callback/loss_monitor.h +2 -2
  128. mindspore/include/api/callback/lr_scheduler.h +5 -5
  129. mindspore/include/api/callback/time_monitor.h +2 -2
  130. mindspore/include/api/callback/train_accuracy.h +4 -6
  131. mindspore/include/api/cfg.h +19 -6
  132. mindspore/include/api/context.h +70 -9
  133. mindspore/include/api/delegate.h +8 -1
  134. mindspore/include/api/dual_abi_helper.h +8 -24
  135. mindspore/include/api/metrics/accuracy.h +2 -2
  136. mindspore/include/api/metrics/metrics.h +4 -3
  137. mindspore/include/api/model.h +9 -4
  138. mindspore/include/api/model_group.h +68 -0
  139. mindspore/include/api/model_parallel_runner.h +17 -17
  140. mindspore/include/api/net.h +12 -11
  141. mindspore/include/api/serialization.h +20 -4
  142. mindspore/include/api/status.h +7 -1
  143. mindspore/include/api/types.h +25 -21
  144. mindspore/include/api/visible.h +4 -0
  145. mindspore/include/c_api/model_c.h +5 -0
  146. mindspore/include/c_api/status_c.h +1 -1
  147. mindspore/include/dataset/config.h +1 -1
  148. mindspore/include/dataset/constants.h +14 -0
  149. mindspore/include/dataset/text.h +59 -0
  150. mindspore/include/dataset/vision.h +56 -117
  151. mindspore/include/dataset/vision_lite.h +102 -0
  152. mindspore/jpeg62.dll +0 -0
  153. mindspore/log.py +28 -28
  154. mindspore/mindrecord/common/exceptions.py +2 -4
  155. mindspore/mindrecord/filereader.py +19 -1
  156. mindspore/mindrecord/filewriter.py +250 -88
  157. mindspore/mindrecord/mindpage.py +13 -13
  158. mindspore/mindrecord/shardheader.py +15 -15
  159. mindspore/mindrecord/shardreader.py +9 -0
  160. mindspore/mindrecord/shardwriter.py +29 -29
  161. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  162. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  163. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  164. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  165. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  166. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  167. mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
  168. mindspore/mindspore_common.dll +0 -0
  169. mindspore/mindspore_core.dll +0 -0
  170. mindspore/mindspore_glog.dll +0 -0
  171. mindspore/mindspore_shared_lib.dll +0 -0
  172. mindspore/msobj140.dll +0 -0
  173. mindspore/mspdb140.dll +0 -0
  174. mindspore/mspdbcore.dll +0 -0
  175. mindspore/mspdbst.dll +0 -0
  176. mindspore/mspft140.dll +0 -0
  177. mindspore/msvcdis140.dll +0 -0
  178. mindspore/msvcp140_1.dll +0 -0
  179. mindspore/msvcp140_2.dll +0 -0
  180. mindspore/msvcp140_atomic_wait.dll +0 -0
  181. mindspore/msvcp140_codecvt_ids.dll +0 -0
  182. mindspore/nn/__init__.py +1 -5
  183. mindspore/nn/cell.py +297 -234
  184. mindspore/nn/dynamic_lr.py +1 -1
  185. mindspore/nn/grad/cell_grad.py +17 -42
  186. mindspore/nn/layer/__init__.py +7 -4
  187. mindspore/nn/layer/activation.py +131 -88
  188. mindspore/nn/layer/basic.py +313 -613
  189. mindspore/nn/layer/channel_shuffle.py +103 -0
  190. mindspore/nn/layer/combined.py +1 -1
  191. mindspore/nn/layer/container.py +52 -6
  192. mindspore/nn/layer/conv.py +112 -43
  193. mindspore/nn/layer/dense.py +10 -9
  194. mindspore/nn/layer/embedding.py +36 -34
  195. mindspore/nn/layer/image.py +123 -27
  196. mindspore/nn/layer/math.py +108 -107
  197. mindspore/nn/layer/normalization.py +212 -366
  198. mindspore/nn/layer/padding.py +370 -42
  199. mindspore/nn/layer/pooling.py +1443 -219
  200. mindspore/nn/layer/rnn_cells.py +11 -16
  201. mindspore/nn/layer/rnns.py +38 -39
  202. mindspore/nn/layer/thor_layer.py +24 -25
  203. mindspore/nn/layer/timedistributed.py +5 -5
  204. mindspore/nn/layer/transformer.py +701 -0
  205. mindspore/nn/learning_rate_schedule.py +8 -8
  206. mindspore/nn/loss/__init__.py +9 -6
  207. mindspore/nn/loss/loss.py +678 -142
  208. mindspore/nn/metrics.py +53 -0
  209. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  210. mindspore/nn/optim/ada_grad.py +8 -8
  211. mindspore/nn/optim/adadelta.py +2 -3
  212. mindspore/nn/optim/adafactor.py +18 -14
  213. mindspore/nn/optim/adam.py +429 -87
  214. mindspore/nn/optim/adamax.py +5 -6
  215. mindspore/nn/optim/adasum.py +10 -8
  216. mindspore/nn/optim/asgd.py +7 -7
  217. mindspore/nn/optim/ftrl.py +81 -11
  218. mindspore/nn/optim/lamb.py +7 -8
  219. mindspore/nn/optim/lars.py +4 -4
  220. mindspore/nn/optim/lazyadam.py +82 -7
  221. mindspore/nn/optim/momentum.py +8 -7
  222. mindspore/nn/optim/optimizer.py +19 -10
  223. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  224. mindspore/nn/optim/rmsprop.py +3 -3
  225. mindspore/nn/optim/rprop.py +20 -16
  226. mindspore/nn/optim/sgd.py +21 -15
  227. mindspore/nn/optim/thor.py +23 -21
  228. mindspore/nn/probability/__init__.py +0 -2
  229. mindspore/nn/probability/bijector/bijector.py +7 -6
  230. mindspore/nn/probability/bijector/invert.py +4 -2
  231. mindspore/nn/probability/bijector/softplus.py +2 -2
  232. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  233. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  234. mindspore/nn/probability/distribution/__init__.py +6 -0
  235. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  236. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  237. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  238. mindspore/nn/probability/distribution/beta.py +1 -1
  239. mindspore/nn/probability/distribution/categorical.py +9 -9
  240. mindspore/nn/probability/distribution/cauchy.py +8 -8
  241. mindspore/nn/probability/distribution/distribution.py +12 -6
  242. mindspore/nn/probability/distribution/exponential.py +5 -5
  243. mindspore/nn/probability/distribution/gamma.py +3 -3
  244. mindspore/nn/probability/distribution/geometric.py +6 -5
  245. mindspore/nn/probability/distribution/gumbel.py +5 -5
  246. mindspore/nn/probability/distribution/half_normal.py +133 -0
  247. mindspore/nn/probability/distribution/laplace.py +128 -0
  248. mindspore/nn/probability/distribution/log_normal.py +0 -1
  249. mindspore/nn/probability/distribution/logistic.py +4 -5
  250. mindspore/nn/probability/distribution/normal.py +11 -15
  251. mindspore/nn/probability/distribution/poisson.py +6 -2
  252. mindspore/nn/probability/distribution/student_t.py +150 -0
  253. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  254. mindspore/nn/probability/distribution/uniform.py +5 -5
  255. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  256. mindspore/nn/reinforcement/tensor_array.py +2 -2
  257. mindspore/nn/sparse/sparse.py +8 -1
  258. mindspore/nn/wrap/cell_wrapper.py +55 -27
  259. mindspore/nn/wrap/grad_reducer.py +20 -11
  260. mindspore/nn/wrap/loss_scale.py +47 -30
  261. mindspore/numpy/array_creations.py +33 -22
  262. mindspore/numpy/array_ops.py +46 -42
  263. mindspore/numpy/logic_ops.py +6 -27
  264. mindspore/numpy/math_ops.py +26 -19
  265. mindspore/numpy/utils.py +1 -8
  266. mindspore/numpy/utils_const.py +112 -62
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/opencv_imgproc452.dll +0 -0
  270. mindspore/ops/__init__.py +6 -3
  271. mindspore/ops/_constants.py +0 -6
  272. mindspore/ops/_grad/__init__.py +2 -1
  273. mindspore/ops/_grad/grad_array_ops.py +209 -152
  274. mindspore/ops/_grad/grad_base.py +55 -17
  275. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  276. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  277. mindspore/ops/_grad/grad_implementations.py +21 -61
  278. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  279. mindspore/ops/_grad/grad_math_ops.py +306 -161
  280. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  281. mindspore/ops/_grad/grad_other_ops.py +1 -1
  282. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  283. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  284. mindspore/ops/_grad/grad_sparse.py +15 -9
  285. mindspore/ops/_grad_experimental/__init__.py +1 -0
  286. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  287. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  288. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  289. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  290. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  291. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  292. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  293. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  294. mindspore/ops/_op_impl/__init__.py +3 -3
  295. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  296. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  297. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  298. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  299. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  300. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  302. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  303. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  304. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  305. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  307. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  314. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  315. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  317. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  318. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  320. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  321. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  322. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  326. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  327. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  328. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  329. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  330. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  331. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  332. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  335. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  336. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  337. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  338. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  339. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  340. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  341. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  342. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  343. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  344. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  345. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  346. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  347. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  349. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  350. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  351. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  352. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  353. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  354. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  355. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  356. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  357. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  358. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  359. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  360. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  361. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  362. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  363. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  365. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  366. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  367. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  368. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  369. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  370. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  371. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  372. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  373. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  374. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  375. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  376. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  377. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  378. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  379. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  380. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  381. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  382. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  383. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  384. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  385. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  386. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  387. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  388. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  389. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  390. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  391. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  392. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  393. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  394. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  395. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  396. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  397. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  398. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  399. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  400. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  401. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  402. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  403. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  404. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  405. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  407. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  408. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  409. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  410. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  411. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  412. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  413. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  414. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  415. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  416. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  417. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  418. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  419. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  420. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  421. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  422. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  423. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  424. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  425. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  426. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  427. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  428. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  429. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  430. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  431. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  432. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  433. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  434. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  435. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  436. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  437. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  441. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  442. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  443. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  444. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  445. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  446. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  447. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  448. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  449. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  450. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  451. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  452. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  453. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  454. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  455. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  456. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  457. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  458. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  459. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  460. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  461. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  462. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  463. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  464. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  465. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  466. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  467. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  468. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  469. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  470. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  471. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  472. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  473. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  474. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  475. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  476. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  477. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  478. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  479. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  480. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  481. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  482. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  483. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  484. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  485. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  486. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  487. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  488. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  489. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  490. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  491. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  492. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  493. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  494. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  495. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  496. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  497. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  498. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  499. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  500. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  501. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  502. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  503. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  504. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  505. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  506. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  507. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  508. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  509. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  510. mindspore/ops/_primitive_cache.py +3 -2
  511. mindspore/ops/_register_for_op.py +11 -0
  512. mindspore/ops/_utils/__init__.py +1 -1
  513. mindspore/ops/_utils/utils.py +20 -41
  514. mindspore/ops/_vmap/__init__.py +2 -2
  515. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  516. mindspore/ops/_vmap/vmap_base.py +24 -10
  517. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  518. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  519. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  520. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  521. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  522. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  523. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  524. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  525. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  526. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  529. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  530. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  531. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  532. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  533. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  534. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  535. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  537. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  539. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  541. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  542. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  543. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  546. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  547. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  548. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  549. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  550. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  551. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  552. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  553. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  554. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  555. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  556. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  559. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  560. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  561. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  565. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  566. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  568. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  569. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  570. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  571. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  572. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  573. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  574. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  576. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  577. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  578. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  579. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  580. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  582. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  583. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  584. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  585. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  586. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  587. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  588. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  589. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  590. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  591. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  592. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  593. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  594. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  595. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  596. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  597. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  598. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  599. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  600. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  601. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  604. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  605. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  606. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  607. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  609. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  610. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  611. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  612. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  613. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  614. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  615. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  616. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  617. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  618. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  619. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  620. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  621. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  622. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  623. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  624. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  625. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  626. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  627. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  628. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  629. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  630. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  632. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  633. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  634. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  635. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  637. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  638. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  639. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  640. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  643. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  644. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  645. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  646. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  647. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  648. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  649. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  650. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  651. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  652. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  653. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  654. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  655. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  656. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  657. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  658. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  660. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  661. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  662. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  663. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  664. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  665. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  666. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  667. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  668. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  669. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  670. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  671. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  672. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  673. mindspore/ops/bprop_mindir/__init__.py +1 -4
  674. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  675. mindspore/ops/composite/__init__.py +12 -13
  676. mindspore/ops/composite/base.py +261 -254
  677. mindspore/ops/composite/env_ops.py +41 -0
  678. mindspore/ops/composite/math_ops.py +197 -156
  679. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  680. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  681. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  682. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  683. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  684. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  685. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  686. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  687. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  688. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  689. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  690. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  691. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  692. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  693. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  694. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  695. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  696. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  697. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  698. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  699. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  700. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  701. mindspore/ops/function/__init__.py +323 -8
  702. mindspore/ops/function/array_func.py +3511 -780
  703. mindspore/ops/function/clip_func.py +329 -0
  704. mindspore/ops/function/debug_func.py +6 -6
  705. mindspore/ops/function/grad/__init__.py +5 -1
  706. mindspore/ops/function/grad/grad_func.py +736 -65
  707. mindspore/ops/function/image_func.py +270 -0
  708. mindspore/ops/function/linalg_func.py +268 -8
  709. mindspore/ops/function/math_func.py +8032 -3164
  710. mindspore/ops/function/nn_func.py +5619 -1855
  711. mindspore/ops/function/other_func.py +115 -0
  712. mindspore/ops/function/parameter_func.py +11 -10
  713. mindspore/ops/function/random_func.py +939 -77
  714. mindspore/ops/function/sparse_func.py +249 -84
  715. mindspore/ops/function/sparse_unary_func.py +2303 -0
  716. mindspore/ops/function/spectral_func.py +146 -0
  717. mindspore/ops/function/vmap_func.py +114 -0
  718. mindspore/ops/functional.py +182 -254
  719. mindspore/ops/op_info_register.py +79 -34
  720. mindspore/ops/operations/__init__.py +210 -118
  721. mindspore/ops/operations/_csr_ops.py +7 -7
  722. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  723. mindspore/ops/operations/_grad_ops.py +447 -322
  724. mindspore/ops/operations/_inner_ops.py +547 -176
  725. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  726. mindspore/ops/operations/_ms_kernel.py +29 -27
  727. mindspore/ops/operations/_ocr_ops.py +11 -11
  728. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  729. mindspore/ops/operations/_quant_ops.py +186 -101
  730. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  731. mindspore/ops/operations/_scalar_ops.py +466 -0
  732. mindspore/ops/operations/_sequence_ops.py +1047 -0
  733. mindspore/ops/operations/_tensor_array.py +10 -11
  734. mindspore/ops/operations/_thor_ops.py +4 -4
  735. mindspore/ops/operations/array_ops.py +1428 -1226
  736. mindspore/ops/operations/comm_ops.py +180 -117
  737. mindspore/ops/operations/control_ops.py +4 -2
  738. mindspore/ops/operations/custom_ops.py +185 -98
  739. mindspore/ops/operations/debug_ops.py +92 -54
  740. mindspore/ops/operations/image_ops.py +406 -211
  741. mindspore/ops/operations/inner_ops.py +42 -53
  742. mindspore/ops/operations/linalg_ops.py +32 -29
  743. mindspore/ops/operations/math_ops.py +2076 -897
  744. mindspore/ops/operations/nn_ops.py +1282 -1252
  745. mindspore/ops/operations/other_ops.py +124 -278
  746. mindspore/ops/operations/random_ops.py +345 -178
  747. mindspore/ops/operations/rl_ops.py +8 -9
  748. mindspore/ops/operations/sparse_ops.py +502 -157
  749. mindspore/ops/operations/spectral_ops.py +107 -0
  750. mindspore/ops/primitive.py +192 -15
  751. mindspore/ops/vm_impl_registry.py +23 -2
  752. mindspore/parallel/__init__.py +6 -1
  753. mindspore/parallel/_auto_parallel_context.py +199 -92
  754. mindspore/parallel/_cell_wrapper.py +4 -2
  755. mindspore/parallel/_cost_model_context.py +3 -0
  756. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  757. mindspore/parallel/_offload_context.py +185 -0
  758. mindspore/parallel/_parallel_serialization.py +167 -28
  759. mindspore/parallel/_ps_context.py +9 -5
  760. mindspore/parallel/_recovery_context.py +1 -1
  761. mindspore/parallel/_tensor.py +9 -1
  762. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  763. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  764. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  765. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  766. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  767. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  768. mindspore/parallel/_utils.py +47 -7
  769. mindspore/parallel/algo_parameter_config.py +5 -1
  770. mindspore/parallel/checkpoint_transform.py +329 -0
  771. mindspore/parallel/shard.py +229 -0
  772. mindspore/perf_msvcbuildinsights.dll +0 -0
  773. mindspore/pgodb140.dll +0 -0
  774. mindspore/pgort140.dll +0 -0
  775. mindspore/profiler/__init__.py +2 -1
  776. mindspore/profiler/common/util.py +4 -3
  777. mindspore/profiler/common/validator/validate_path.py +2 -2
  778. mindspore/profiler/envprofiling.py +249 -0
  779. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  780. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  781. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  782. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  783. mindspore/profiler/parser/framework_parser.py +42 -16
  784. mindspore/profiler/parser/hccl_parser.py +158 -158
  785. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  786. mindspore/profiler/parser/integrator.py +18 -1579
  787. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  788. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  789. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  790. mindspore/profiler/parser/optime_parser.py +17 -18
  791. mindspore/profiler/parser/profiler_info.py +108 -0
  792. mindspore/profiler/parser/step_trace_parser.py +1 -1
  793. mindspore/profiler/profiling.py +396 -194
  794. mindspore/rewrite/__init__.py +6 -2
  795. mindspore/rewrite/api/node.py +51 -110
  796. mindspore/rewrite/api/node_type.py +10 -6
  797. mindspore/rewrite/api/pattern_engine.py +51 -7
  798. mindspore/rewrite/api/scoped_value.py +64 -53
  799. mindspore/rewrite/api/symbol_tree.py +108 -61
  800. mindspore/rewrite/api/tree_node_helper.py +2 -3
  801. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  802. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  803. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  804. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  805. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  806. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  807. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  808. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  809. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  810. mindspore/rewrite/common/__init__.py +2 -0
  811. mindspore/rewrite/common/event.py +1 -1
  812. mindspore/rewrite/common/observable.py +1 -1
  813. mindspore/rewrite/common/observer.py +1 -1
  814. mindspore/rewrite/common/rewrite_elog.py +35 -0
  815. mindspore/rewrite/namer.py +2 -2
  816. mindspore/rewrite/namespace.py +14 -4
  817. mindspore/rewrite/node.py +161 -13
  818. mindspore/rewrite/parser.py +0 -1
  819. mindspore/rewrite/parser_register.py +0 -1
  820. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  821. mindspore/rewrite/parsers/assign_parser.py +267 -67
  822. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  823. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  824. mindspore/rewrite/parsers/constant_parser.py +101 -0
  825. mindspore/rewrite/parsers/container_parser.py +88 -0
  826. mindspore/rewrite/parsers/for_parser.py +28 -15
  827. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  828. mindspore/rewrite/parsers/if_parser.py +11 -28
  829. mindspore/rewrite/parsers/module_parser.py +9 -6
  830. mindspore/rewrite/parsers/return_parser.py +3 -2
  831. mindspore/rewrite/sparsify/__init__.py +0 -0
  832. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  833. mindspore/rewrite/sparsify/sparsify.py +109 -0
  834. mindspore/rewrite/sparsify/utils.py +173 -0
  835. mindspore/rewrite/symbol_tree.py +322 -109
  836. mindspore/rewrite/symbol_tree_builder.py +45 -8
  837. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  838. mindspore/rewrite/topological_manager.py +1 -2
  839. mindspore/run_check/_check_version.py +209 -112
  840. mindspore/run_check/run_check.py +2 -1
  841. mindspore/tbbmalloc.dll +0 -0
  842. mindspore/tinyxml2.dll +0 -0
  843. mindspore/train/__init__.py +6 -4
  844. mindspore/train/_utils.py +28 -5
  845. mindspore/train/amp.py +321 -50
  846. mindspore/train/callback/__init__.py +3 -1
  847. mindspore/train/callback/_backup_and_restore.py +120 -0
  848. mindspore/train/callback/_callback.py +8 -8
  849. mindspore/train/callback/_checkpoint.py +12 -9
  850. mindspore/train/callback/_early_stop.py +13 -7
  851. mindspore/train/callback/_history.py +8 -8
  852. mindspore/train/callback/_lambda_callback.py +6 -6
  853. mindspore/train/callback/_landscape.py +36 -38
  854. mindspore/train/callback/_loss_monitor.py +12 -6
  855. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  856. mindspore/train/callback/_on_request_exit.py +212 -0
  857. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  858. mindspore/train/callback/_summary_collector.py +27 -19
  859. mindspore/train/callback/_time_monitor.py +13 -7
  860. mindspore/train/checkpoint_pb2.py +68 -8
  861. mindspore/train/data_sink.py +122 -33
  862. mindspore/train/dataset_helper.py +28 -87
  863. mindspore/train/loss_scale_manager.py +4 -7
  864. mindspore/{nn → train}/metrics/__init__.py +20 -20
  865. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  866. mindspore/{nn → train}/metrics/auc.py +4 -4
  867. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  868. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  869. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  870. mindspore/{nn → train}/metrics/dice.py +6 -5
  871. mindspore/{nn → train}/metrics/error.py +7 -5
  872. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  873. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  874. mindspore/{nn → train}/metrics/loss.py +4 -3
  875. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  876. mindspore/{nn → train}/metrics/metric.py +6 -5
  877. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  878. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  879. mindspore/{nn → train}/metrics/precision.py +5 -4
  880. mindspore/{nn → train}/metrics/recall.py +5 -4
  881. mindspore/{nn → train}/metrics/roc.py +7 -6
  882. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  883. mindspore/{nn → train}/metrics/topk.py +7 -5
  884. mindspore/train/mind_ir_pb2.py +339 -32
  885. mindspore/train/model.py +113 -84
  886. mindspore/train/serialization.py +547 -167
  887. mindspore/train/summary/_summary_adapter.py +1 -1
  888. mindspore/train/summary/summary_record.py +43 -12
  889. mindspore/train/train_thor/convert_utils.py +7 -1
  890. mindspore/train/train_thor/dataset_helper.py +3 -3
  891. mindspore/train/train_thor/model_thor.py +0 -4
  892. mindspore/turbojpeg.dll +0 -0
  893. mindspore/vcmeta.dll +0 -0
  894. mindspore/vcruntime140.dll +0 -0
  895. mindspore/vcruntime140_1.dll +0 -0
  896. mindspore/version.py +1 -1
  897. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  898. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
  899. mindspore/compression/common/constant.py +0 -124
  900. mindspore/compression/export/__init__.py +0 -19
  901. mindspore/compression/export/quant_export.py +0 -514
  902. mindspore/compression/quant/qat.py +0 -636
  903. mindspore/compression/quant/quant_utils.py +0 -462
  904. mindspore/compression/quant/quantizer.py +0 -68
  905. mindspore/libatomic-1.dll +0 -0
  906. mindspore/libgcc_s_seh-1.dll +0 -0
  907. mindspore/libgfortran-4.dll +0 -0
  908. mindspore/libgomp-1.dll +0 -0
  909. mindspore/libjpeg-62.dll +0 -0
  910. mindspore/libmindspore.dll +0 -0
  911. mindspore/libmindspore_common.dll +0 -0
  912. mindspore/libmindspore_core.dll +0 -0
  913. mindspore/libmindspore_glog.dll +0 -0
  914. mindspore/libnnacl.dll +0 -0
  915. mindspore/libopencv_core452.dll +0 -0
  916. mindspore/libopencv_imgcodecs452.dll +0 -0
  917. mindspore/libopencv_imgproc452.dll +0 -0
  918. mindspore/libquadmath-0.dll +0 -0
  919. mindspore/libsqlite3.dll +0 -0
  920. mindspore/libssp-0.dll +0 -0
  921. mindspore/libstdc++-6.dll +0 -0
  922. mindspore/libtinyxml2.dll +0 -0
  923. mindspore/libturbojpeg.dll +0 -0
  924. mindspore/libwinpthread-1.dll +0 -0
  925. mindspore/nn/layer/quant.py +0 -1868
  926. mindspore/nn/layer/rnn_utils.py +0 -90
  927. mindspore/nn/probability/dpn/__init__.py +0 -22
  928. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  929. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  930. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  931. mindspore/nn/probability/infer/__init__.py +0 -22
  932. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  933. mindspore/nn/probability/infer/variational/svi.py +0 -84
  934. mindspore/nn/probability/toolbox/__init__.py +0 -22
  935. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  936. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  937. mindspore/nn/probability/transforms/__init__.py +0 -22
  938. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  939. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  940. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  941. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  942. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  943. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  944. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  945. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  946. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  947. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  948. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  949. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  950. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  951. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  952. mindspore/ops/composite/array_ops.py +0 -210
  953. mindspore/ops/composite/clip_ops.py +0 -238
  954. mindspore/ops/composite/random_ops.py +0 -426
  955. mindspore/ops/composite/vmap_ops.py +0 -38
  956. mindspore/ops/operations/sponge_ops.py +0 -3531
  957. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  958. mindspore/parallel/nn/__init__.py +0 -42
  959. mindspore/parallel/nn/loss.py +0 -22
  960. mindspore/parallel/nn/moe.py +0 -21
  961. mindspore/parallel/nn/op_parallel_config.py +0 -22
  962. mindspore/parallel/nn/transformer.py +0 -31
  963. mindspore/run_check/_check_deps_version.py +0 -84
  964. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  965. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  966. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -1,3531 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """
17
- Note:
18
- SPONGE operators. This is an experimental interface that is subject to change and/or deletion.
19
- """
20
-
21
- import math
22
-
23
- from ..primitive import PrimitiveWithInfer, prim_attr_register
24
- from ..._checkparam import Rel
25
- from ..._checkparam import Validator as validator
26
- from ...common import dtype as mstype
27
-
28
-
29
class BondForce(PrimitiveWithInfer):
    """
    Calculate the force exerted by the simple harmonic bond on the corresponding atoms.
    Assume the number of harmonic bonds is m and the number of atoms is n.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer to the webpage
    `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::

        dr = (x_1-x_2, y_1-y_2, z_1-z_2)

    .. math::

        F = (F_x, F_y, F_z) = 2*k*(1 - r_0/|dr|)*dr

    Args:
        bond_numbers(int): the number of harmonic bonds m.
        atom_numbers(int): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor (x, y, z),
          between the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The first atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The second atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **bond_k** (Tensor) - The force constant of each bond.
          The data type is float32 and the shape is :math:`(m,)`.
        - **bond_r0** (Tensor) - The equilibrium length of each bond.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **frc_f** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, bond_numbers, atom_numbers):
        """Initialize BondForce."""
        # Both counts must be Python ints; they are later compared against input shapes.
        validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.bond_numbers = bond_numbers
        self.atom_numbers = atom_numbers
        self.add_prim_attr('bond_numbers', self.bond_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'],
                                outputs=['frc_f'])

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape):
        """
        Validate the input shapes and return the output shape.

        The rank of every input is checked first, then the individual dimensions:
        `uint_crd_f` must be (atom_numbers, 3), `scaler_f` must be (3,), and the four
        per-bond inputs must each be (bond_numbers,).

        Returns:
            The shape of `uint_crd_f`, which is also the shape of the output `frc_f`.
        """
        cls_name = self.name
        n = self.atom_numbers
        m = self.bond_numbers

        # Rank checks come before any indexing into the shapes below.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
        validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
        validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
        validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)

        # Per-atom input and the 3-component scale factor.
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)

        # Per-bond inputs. The atom_a check previously carried the label
        # "uint_crd_f_shape", which pointed users at the wrong input on failure.
        validator.check_int(atom_a_shape[0], m, Rel.EQ, "atom_a_shape", cls_name)
        validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
        validator.check_int(bond_k_shape[0], m, Rel.EQ, "bond_k_shape", cls_name)
        validator.check_int(bond_r0_shape[0], m, Rel.EQ, "bond_r0_shape", cls_name)
        return uint_crd_f_shape

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
        """
        Validate the input dtypes and return the output dtype.

        `uint_crd_f` must be uint32, `atom_a`/`atom_b` must be int32, and
        `scaler_f`/`bond_k`/`bond_r0` must be float32.

        Returns:
            The dtype of `bond_r0` (validated above as float32), used as the dtype of `frc_f`.
        """
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('bond_k', bond_k_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('bond_r0', bond_r0_type, [mstype.float32], self.name)
        return bond_r0_type
112
-
113
-
114
- class BondEnergy(PrimitiveWithInfer):
115
- """
116
- Calculate the harmonic potential energy between each bonded atom pair.
117
- Assume our system has n atoms and m harmonic bonds.
118
-
119
- Because there is a large amount of inputs and each of them are related,
120
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
121
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
122
-
123
- .. math::
124
-
125
- dr = (x_1-x_2, y_1-y_2, z_1-z_2)
126
-
127
- .. math::
128
-
129
- E = k*(|dr| - r_0)^2
130
-
131
- Args:
132
- atom_numbers(int32): the number of atoms n.
133
- bond_numbers(int32): the number of harmonic bonds m.
134
-
135
- Inputs:
136
- - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
137
- The data type is uint32 and the shape is :math:`(n, 3)`.
138
- - **scaler_f** (Tensor) - The 3-D scale factor (x, y, z),
139
- between the real space float coordinates and the unsigned int coordinates.
140
- The data type is float32 and the shape is :math:`(3,)`.
141
- - **atom_a** (Tensor) - The first atom index of each bond.
142
- The data type is int32 and the shape is :math:`(m,)`.
143
- - **atom_b** (Tensor) - The second atom index of each bond.
144
- The data type is int32 and the shape is :math:`(m,)`.
145
- - **bond_k** (Tensor) - The force constant of each bond.
146
- The data type is float32 and the shape is :math:`(m,)`.
147
- - **bond_r0** (Tensor) - The equlibrium length of each bond.
148
- The data type is float32 and the shape is :math:`(m,)`.
149
-
150
- Outputs:
151
- - **bond_ene** (Tensor) - The harmonic potential energy for each bond.
152
- The data type is float32 and the shape is :math:`(m,)`.
153
-
154
- Supported Platforms:
155
- ``GPU``
156
- """
157
-
158
- @prim_attr_register
159
- def __init__(self, bond_numbers, atom_numbers):
160
- """Initialize BondEnergy."""
161
- validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
162
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
163
- self.bond_numbers = bond_numbers
164
- self.atom_numbers = atom_numbers
165
- self.add_prim_attr('bond_numbers', self.bond_numbers)
166
- self.add_prim_attr('atom_numbers', self.atom_numbers)
167
- self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'],
168
- outputs=['bond_ene'])
169
-
170
- def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape):
171
- cls_name = self.name
172
- n = self.atom_numbers
173
- m = self.bond_numbers
174
- validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
175
- validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
176
- validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
177
- validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
178
- validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
179
- validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
180
- validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
181
- validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
182
- validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
183
- validator.check_int(atom_a_shape[0], m, Rel.EQ, "uint_crd_f_shape", cls_name)
184
- validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
185
- validator.check_int(bond_k_shape[0], m, Rel.EQ, "bond_k_shape", cls_name)
186
- validator.check_int(bond_r0_shape[0], m, Rel.EQ, "bond_r0_shape", cls_name)
187
-
188
- return bond_k_shape
189
-
190
- def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
191
- validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
192
- validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
193
- validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
194
- validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
195
- validator.check_tensor_dtype_valid('bond_k', bond_k_type, [mstype.float32], self.name)
196
- validator.check_tensor_dtype_valid('bond_r0', bond_r0_type, [mstype.float32], self.name)
197
- return bond_r0_type
198
-
199
-
200
- class BondAtomEnergy(PrimitiveWithInfer):
201
- """
202
- Add the potential energy caused by simple harmonic bonds to the total
203
- potential energy of each atom.
204
-
205
- The calculation formula is the same as operator BondEnergy().
206
-
207
- Because there is a large amount of inputs and each of them are related,
208
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
209
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
210
-
211
- Args:
212
- atom_numbers(int32): the number of atoms n.
213
- bond_numbers(int32): the number of harmonic bonds m.
214
-
215
- Inputs:
216
- - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
217
- The data type is uint32 and the shape is :math:`(n, 3)`.
218
- - **scaler_f** (Tensor) - The 3-D scale factor (x, y, z),
219
- between the real space float coordinates and the unsigned int coordinates.
220
- The data type is float32 and the shape is :math:`(3,)`.
221
- - **atom_a** (Tensor) - The first atom index of each bond.
222
- The data type is int32 and the shape is :math:`(m,)`.
223
- - **atom_b** (Tensor) - The second atom index of each bond.
224
- The data type is int32 and the shape is :math:`(m,)`.
225
- - **bond_k** (Tensor) - The force constant of each bond.
226
- The data type is float32 and the shape is :math:`(m,)`.
227
- - **bond_r0** (Tensor) - The equlibrium length of each bond.
228
- The data type is float32 and the shape is :math:`(m,)`.
229
-
230
- Outputs:
231
- - **atom_ene** (Tensor) - The accumulated potential energy for each atom.
232
- The data type is float32 and the shape is :math:`(n,)`.
233
-
234
- Supported Platforms:
235
- ``GPU``
236
- """
237
-
238
- @prim_attr_register
239
- def __init__(self, bond_numbers, atom_numbers):
240
- """Initialize BondAtomEnergy."""
241
- validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
242
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
243
- self.bond_numbers = bond_numbers
244
- self.atom_numbers = atom_numbers
245
- self.add_prim_attr('bond_numbers', self.bond_numbers)
246
- self.add_prim_attr('atom_numbers', self.atom_numbers)
247
- self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'],
248
- outputs=['atom_ene'])
249
-
250
- def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape):
251
- cls_name = self.name
252
- n = self.atom_numbers
253
- m = self.bond_numbers
254
- validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
255
- validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
256
- validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
257
- validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
258
- validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
259
- validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
260
- validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
261
- validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
262
- validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
263
- validator.check_int(atom_a_shape[0], m, Rel.EQ, "uint_crd_f_shape", cls_name)
264
- validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
265
- validator.check_int(bond_k_shape[0], m, Rel.EQ, "bond_k_shape", cls_name)
266
- validator.check_int(bond_r0_shape[0], m, Rel.EQ, "bond_r0_shape", cls_name)
267
- return [n,]
268
-
269
- def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
270
- validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
271
- validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
272
- validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
273
- validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
274
- validator.check_tensor_dtype_valid('bond_k', bond_k_type, [mstype.float32], self.name)
275
- validator.check_tensor_dtype_valid('bond_r0', bond_r0_type, [mstype.float32], self.name)
276
- return bond_r0_type
277
-
278
-
279
- class BondForceWithAtomEnergy(PrimitiveWithInfer):
280
- """
281
- Calculate bond force and harmonic potential energy together.
282
-
283
- The calculation formula is the same as operator BondForce() and BondEnergy().
284
-
285
- Because there is a large amount of inputs and each of them are related,
286
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
287
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
288
-
289
- Args:
290
- atom_numbers(int32): the number of atoms n.
291
- bond_numbers(int32): the number of harmonic bonds m.
292
-
293
- Inputs:
294
- - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
295
- The data type is uint32 and the shape is :math:`(n, 3)`.
296
- - **scaler_f** (Tensor) - The 3-D scale factor (x, y, z),
297
- between the real space float coordinates and the unsigned int coordinates.
298
- The data type is float32 and the shape is :math:`(3,)`.
299
- - **atom_a** (Tensor) - The first atom index of each bond.
300
- The data type is int32 and the shape is :math:`(m,)`.
301
- - **atom_b** (Tensor) - The second atom index of each bond.
302
- The data type is int32 and the shape is :math:`(m,)`.
303
- - **bond_k** (Tensor) - The force constant of each bond.
304
- The data type is float32 and the shape is :math:`(m,)`.
305
- - **bond_r0** (Tensor) - The equlibrium length of each bond.
306
- The data type is float32 and the shape is :math:`(m,)`.
307
-
308
- Outputs:
309
- - **frc_f** (Tensor) - The force felt by each atom.
310
- The data type is float32 and the shape is :math:`(n, 3)`.
311
- - **atom_e** (Tensor) - The accumulated potential energy for each atom.
312
- The data type is float32 and the shape is :math:`(n,)`.
313
-
314
- Supported Platforms:
315
- ``GPU``
316
- """
317
-
318
- @prim_attr_register
319
- def __init__(self, bond_numbers, atom_numbers):
320
- """Initialize BondForceWithAtomEnergy."""
321
- validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
322
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
323
- self.bond_numbers = bond_numbers
324
- self.atom_numbers = atom_numbers
325
- self.add_prim_attr('bond_numbers', self.bond_numbers)
326
- self.add_prim_attr('atom_numbers', self.atom_numbers)
327
- self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'],
328
- outputs=['frc_f', 'atom_e'])
329
-
330
- def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape):
331
- cls_name = self.name
332
- n = self.atom_numbers
333
- m = self.bond_numbers
334
- validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
335
- validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
336
- validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
337
- validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
338
- validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
339
- validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
340
- validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
341
- validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
342
- validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
343
- validator.check_int(atom_a_shape[0], m, Rel.EQ, "uint_crd_f_shape", cls_name)
344
- validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
345
- validator.check_int(bond_k_shape[0], m, Rel.EQ, "bond_k_shape", cls_name)
346
- validator.check_int(bond_r0_shape[0], m, Rel.EQ, "bond_r0_shape", cls_name)
347
-
348
- return uint_crd_f_shape, [n,]
349
-
350
- def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
351
- validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
352
- validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
353
-
354
- validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
355
- validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
356
-
357
- validator.check_tensor_dtype_valid('bond_k', bond_k_type, [mstype.float32], self.name)
358
- validator.check_tensor_dtype_valid('bond_r0', bond_r0_type, [mstype.float32], self.name)
359
- return bond_r0_type, bond_r0_type
360
-
361
-
362
- class BondForceWithAtomVirial(PrimitiveWithInfer):
363
- """
364
- Calculate bond force and the virial coefficient caused by simple harmonic
365
- bond for each atom together.
366
-
367
- The calculation formula of the force part is the same as operator BondForce().
368
-
369
- Because there is a large amount of inputs and each of them are related,
370
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
371
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
372
-
373
- The Virial part is as follows:
374
-
375
- .. math::
376
-
377
- dr = (x_1-x_2, y_1-y_2, z_1-z_2)
378
-
379
- .. math::
380
-
381
- virial = |dr|*(|dr| - r_0)*k
382
-
383
- Args:
384
- atom_numbers(int32): the number of atoms n.
385
- bond_numbers(int32): the number of harmonic bonds m.
386
-
387
- Inputs:
388
- - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
389
- The data type is uint32 and the shape is :math:`(n, 3)`.
390
- - **scaler_f** (Tensor) - The 3-D scale factor (x, y, z),
391
- between the real space float coordinates and the unsigned int coordinates.
392
- The data type is float32 and the shape is :math:`(3,)`.
393
- - **atom_a** (Tensor) - The first atom index of each bond.
394
- The data type is int32 and the shape is :math:`(m,)`.
395
- - **atom_b** (Tensor) - The second atom index of each bond.
396
- The data type is int32 and the shape is :math:`(m,)`.
397
- - **bond_k** (Tensor) - The force constant of each bond.
398
- The data type is float32 and the shape is :math:`(m,)`.
399
- - **bond_r0** (Tensor) - The equlibrium length of each bond.
400
- The data type is float32 and the shape is :math:`(m,)`.
401
-
402
- Outputs:
403
- - **frc_f** (Tensor) - Same as operator BondForce().
404
- The data type is float32 and the shape is :math:`(n, 3)`.
405
- - **atom_v** (Tensor) - The accumulated virial coefficient for each atom.
406
- The data type is float32 and the shape is :math:`(n,)`.
407
-
408
- Supported Platforms:
409
- ``GPU``
410
- """
411
-
412
- @prim_attr_register
413
- def __init__(self, bond_numbers, atom_numbers):
414
- """Initialize BondForceWithAtomVirial."""
415
- validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
416
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
417
- self.bond_numbers = bond_numbers
418
- self.atom_numbers = atom_numbers
419
- self.add_prim_attr('bond_numbers', self.bond_numbers)
420
- self.add_prim_attr('atom_numbers', self.atom_numbers)
421
- self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'],
422
- outputs=['frc_f', 'atom_v'])
423
-
424
- def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape):
425
- cls_name = self.name
426
- n = self.atom_numbers
427
- m = self.bond_numbers
428
- validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
429
- validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
430
- validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
431
- validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
432
- validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
433
- validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
434
- validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
435
- validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
436
- validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
437
- validator.check_int(atom_a_shape[0], m, Rel.EQ, "uint_crd_f_shape", cls_name)
438
- validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
439
- validator.check_int(bond_k_shape[0], m, Rel.EQ, "bond_k_shape", cls_name)
440
- validator.check_int(bond_r0_shape[0], m, Rel.EQ, "bond_r0_shape", cls_name)
441
-
442
- return uint_crd_f_shape, [n,]
443
-
444
- def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
445
- validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
446
- validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
447
-
448
- validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
449
- validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
450
-
451
- validator.check_tensor_dtype_valid('bond_k', bond_k_type, [mstype.float32], self.name)
452
- validator.check_tensor_dtype_valid('bond_r0', bond_r0_type, [mstype.float32], self.name)
453
- return bond_r0_type, bond_r0_type
454
-
455
-
456
- class DihedralForce(PrimitiveWithInfer):
457
- """
458
- Calculate the force exerted by the dihedral term which made of 4-atoms
459
- on the corresponding atoms. Assume the number of dihedral terms is m and
460
- the number of atoms is n.
461
-
462
- Because there is a large amount of inputs and each of them are related,
463
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
464
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
465
-
466
- Args:
467
- dihedral_numbers(int32): the number of dihedral terms m.
468
-
469
- Inputs:
470
- - **uint_crd_f** (Tensor) - The unsigned int coordinates
471
- value of each atom. The data type is uint32 and the shape is :math:`(n, 3)`.
472
- - **scaler_f** (Tensor) - The 3-D scale factor between
473
- the real space float coordinates and the unsigned int coordinates.
474
- The data type is float32 and the shape is :math:`(3,)`.
475
- - **atom_a** (Tensor) - The 1st atom index of each dihedral.
476
- The data type is int32 and the shape is :math:`(m,)`.
477
- - **atom_b** (Tensor) - The 2nd atom index of each dihedral.
478
- The data type is int32 and the shape is :math:`(m,)`.
479
- - **atom_c** (Tensor) - The 3rd atom index of each dihedral.
480
- The data type is int32 and the shape is :math:`(m,)`.
481
- - **atom_d** (Tensor) - The 4th atom index of each dihedral.
482
- 4 atoms are connected in the form a-b-c-d.
483
- The data type is int32 and the shape is :math:`(m,)`.
484
- - **ipn** (Tensor) - The period of dihedral angle of each dihedral.
485
- The data type is int32 and the shape is :math:`(m,)`.
486
- - **pk** (Tensor) - The force constant of each dihedral.
487
- The data type is float32 and the shape is :math:`(m,)`.
488
- - **gamc** (Tensor) - k*cos(phi_0) of each dihedral.
489
- The data type is float32 and the shape is :math:`(m,)`.
490
- - **gams** (Tensor) - k*sin(phi_0) of each dihedral.
491
- The data type is float32 and the shape is :math:`(m,)`.
492
- - **pn** (Tensor) - The floating point form of ipn.
493
- The data type is float32 and the shape is :math:`(m,)`.
494
-
495
- Outputs:
496
- - **frc_f** (Tensor) - The force felt by each atom.
497
- The data type is float32 and the shape is :math:`(n, 3)`.
498
-
499
- Supported Platforms:
500
- ``GPU``
501
- """
502
-
503
- @prim_attr_register
504
- def __init__(self, dihedral_numbers):
505
- """Initialize DihedralForce."""
506
- validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
507
- self.dihedral_numbers = dihedral_numbers
508
- self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
509
- 'gamc', 'gams', 'pn'],
510
- outputs=['frc_f'])
511
- self.add_prim_attr('dihedral_numbers', self.dihedral_numbers)
512
-
513
- def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, atom_d_shape,
514
- ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
515
- cls_name = self.name
516
- m = self.dihedral_numbers
517
- validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
518
- validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
519
- validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
520
- validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
521
- validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
522
- validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
523
- validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
524
- validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
525
- validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
526
- validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
527
- validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
528
-
529
- validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
530
- validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
531
- validator.check_int(atom_a_shape[0], m, Rel.EQ, "atom_a_shape", cls_name)
532
- validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
533
- validator.check_int(atom_c_shape[0], m, Rel.EQ, "atom_c_shape", cls_name)
534
- validator.check_int(atom_d_shape[0], m, Rel.EQ, "atom_d_shape", cls_name)
535
- validator.check_int(ipn_shape[0], m, Rel.EQ, "ipn_shape", cls_name)
536
- validator.check_int(pk_shape[0], m, Rel.EQ, "pk_shape", cls_name)
537
- validator.check_int(gamc_shape[0], m, Rel.EQ, "gamc_shape", cls_name)
538
- validator.check_int(gams_shape[0], m, Rel.EQ, "gams_shape", cls_name)
539
- validator.check_int(pn_shape[0], m, Rel.EQ, "pn_shape", cls_name)
540
- return uint_crd_f_shape
541
-
542
- def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
543
- ipn_type, pk_type, gamc_type, gams_type, pn_type):
544
- validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
545
- validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
546
- validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
547
- validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
548
- validator.check_tensor_dtype_valid('atom_c', atom_c_type, [mstype.int32], self.name)
549
- validator.check_tensor_dtype_valid('atom_d', atom_d_type, [mstype.int32], self.name)
550
- validator.check_tensor_dtype_valid('ipn', ipn_type, [mstype.int32], self.name)
551
- validator.check_tensor_dtype_valid('pk', pk_type, [mstype.float32], self.name)
552
- validator.check_tensor_dtype_valid('gamc', gamc_type, [mstype.float32], self.name)
553
- validator.check_tensor_dtype_valid('gams', gams_type, [mstype.float32], self.name)
554
- validator.check_tensor_dtype_valid('pn', pn_type, [mstype.float32], self.name)
555
-
556
- return pn_type
557
-
558
-
559
- class DihedralEnergy(PrimitiveWithInfer):
560
- """
561
- Calculate the potential energy caused by dihedral terms for each 4-atom pair.
562
- Assume our system has n atoms and m dihedral terms.
563
-
564
- Because there is a large amount of inputs and each of them are related,
565
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
566
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
567
-
568
- Args:
569
- dihedral_numbers(int32): the number of dihedral terms m.
570
-
571
- Inputs:
572
- - **uint_crd_f** (Tensor) - The unsigned int coordinates
573
- value of each atom.
574
- The data type is uint32 and the shape is :math:`(n, 3)`.
575
- - **scaler_f** (Tensor) - The 3-D scale factor between
576
- the real space float coordinates and the unsigned int coordinates.
577
- The data type is float32 and the shape is :math:`(3,)`.
578
- - **atom_a** (Tensor) - The 1st atom index of each dihedral.
579
- The data type is int32 and the shape is :math:`(m,)`.
580
- - **atom_b** (Tensor) - The 2nd atom index of each dihedral.
581
- The data type is int32 and the shape is :math:`(m,)`.
582
- - **atom_c** (Tensor) - The 3rd atom index of each dihedral.
583
- The data type is int32 and the shape is :math:`(m,)`.
584
- - **atom_d** (Tensor) - The 4th atom index of each dihedral.
585
- 4 atoms are connected in the form a-b-c-d.
586
- The data type is int32 and the shape is :math:`(m,)`.
587
- - **ipn** (Tensor) - The period of dihedral angle of each dihedral.
588
- The data type is int32 and the shape is :math:`(m,)`.
589
- - **pk** (Tensor) - The force constant of each dihedral.
590
- The data type is int32 and the shape is :math:`(m,)`.
591
- - **gamc** (Tensor) - k*cos(phi_0) of each dihedral.
592
- The data type is float32 and the shape is :math:`(m,)`.
593
- - **gams** (Tensor) - k*sin(phi_0) of each dihedral.
594
- The data type is float32 and the shape is :math:`(m,)`.
595
- - **pn** (Tensor) - The floating point form of ipn.
596
- The data type is float32 and the shape is :math:`(m,)`.
597
-
598
- Outputs:
599
- - **ene** (Tensor) - The potential energy for each
600
- dihedral term. The data type is float32 and the shape is :math:`(m,)`.
601
-
602
- Supported Platforms:
603
- ``GPU``
604
- """
605
-
606
- @prim_attr_register
607
- def __init__(self, dihedral_numbers):
608
- """Initialize DihedralEnergy."""
609
- validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
610
- self.dihedral_numbers = dihedral_numbers
611
- self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
612
- 'gamc', 'gams', 'pn'],
613
- outputs=['ene'])
614
- self.add_prim_attr('dihedral_numbers', self.dihedral_numbers)
615
-
616
- def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, atom_d_shape,
617
- ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
618
- cls_name = self.name
619
- m = self.dihedral_numbers
620
- validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
621
- validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
622
- validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
623
- validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
624
- validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
625
- validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
626
- validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
627
- validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
628
- validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
629
- validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
630
- validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
631
-
632
- validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
633
- validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
634
- validator.check_int(atom_a_shape[0], m, Rel.EQ, "atom_a_shape", cls_name)
635
- validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
636
- validator.check_int(atom_c_shape[0], m, Rel.EQ, "atom_c_shape", cls_name)
637
- validator.check_int(atom_d_shape[0], m, Rel.EQ, "atom_d_shape", cls_name)
638
- validator.check_int(ipn_shape[0], m, Rel.EQ, "ipn_shape", cls_name)
639
- validator.check_int(pk_shape[0], m, Rel.EQ, "pk_shape", cls_name)
640
- validator.check_int(gamc_shape[0], m, Rel.EQ, "gamc_shape", cls_name)
641
- validator.check_int(gams_shape[0], m, Rel.EQ, "gams_shape", cls_name)
642
- validator.check_int(pn_shape[0], m, Rel.EQ, "pn_shape", cls_name)
643
- return [m,]
644
-
645
- def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
646
- ipn_type, pk_type, gamc_type, gams_type, pn_type):
647
- validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
648
- validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
649
- validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
650
- validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
651
- validator.check_tensor_dtype_valid('atom_c', atom_c_type, [mstype.int32], self.name)
652
- validator.check_tensor_dtype_valid('atom_d', atom_d_type, [mstype.int32], self.name)
653
- validator.check_tensor_dtype_valid('ipn', ipn_type, [mstype.int32], self.name)
654
- validator.check_tensor_dtype_valid('pk', pk_type, [mstype.float32], self.name)
655
- validator.check_tensor_dtype_valid('gamc', gamc_type, [mstype.float32], self.name)
656
- validator.check_tensor_dtype_valid('gams', gams_type, [mstype.float32], self.name)
657
- validator.check_tensor_dtype_valid('pn', pn_type, [mstype.float32], self.name)
658
-
659
- return pn_type
660
-
661
-
662
- class DihedralAtomEnergy(PrimitiveWithInfer):
663
- """
664
- Add the potential energy caused by dihedral terms to the total potential
665
- energy of each atom.
666
-
667
- Because there is a large amount of inputs and each of them are related,
668
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
669
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
670
-
671
- The calculation formula is the same as operator DihedralEnergy().
672
-
673
- Args:
674
- dihedral_numbers(int32): the number of dihedral terms m.
675
-
676
- Inputs:
677
- - **uint_crd_f** (Tensor) - The unsigned int coordinates
678
- value of each atom. The data type is uint32 and the shape is :math:`(n, 3)`.
679
- - **scaler_f** (Tensor) - The 3-D scale factor between
680
- the real space float coordinates and the unsigned int coordinates.
681
- The data type is float32 and the shape is :math:`(3,)`.
682
- - **atom_a** (Tensor) - The 1st atom index of each dihedral.
683
- The data type is int32 and the shape is :math:`(m,)`.
684
- - **atom_b** (Tensor) - The 2nd atom index of each dihedral.
685
- The data type is int32 and the shape is :math:`(m,)`.
686
- - **atom_c** (Tensor) - The 3rd atom index of each dihedral.
687
- The data type is int32 and the shape is :math:`(m,)`.
688
- - **atom_d** (Tensor) - The 4th atom index of each dihedral.
689
- 4 atoms are connected in the form a-b-c-d. The data type is int32 and the shape is :math:`(m,)`.
690
- - **ipn** (Tensor) - The period of dihedral angle of each dihedral.
691
- The data type is int32 and the shape is :math:`(m,)`.
692
- - **pk** (Tensor) - The force constant of each dihedral.
693
- The data type is float32 and the shape is :math:`(m,)`.
694
- - **gamc** (Tensor) - k*cos(phi_0) of each dihedral.
695
- The data type is float32 and the shape is :math:`(m,)`.
696
- - **gams** (Tensor) - k*sin(phi_0) of each dihedral.
697
- The data type is float32 and the shape is :math:`(m,)`.
698
- - **pn** (Tensor) - The floating point form of ipn.
699
- The data type is float32 and the shape is :math:`(m,)`.
700
-
701
- Outputs:
702
- - **ene** (Tensor) - The accumulated potential
703
- energy for each atom. The data type is float32 and the shape is :math:`(n,)`.
704
-
705
- Supported Platforms:
706
- ``GPU``
707
- """
708
-
709
- @prim_attr_register
710
- def __init__(self, dihedral_numbers):
711
- """Initialize DihedralAtomEnergy."""
712
- validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
713
- self.dihedral_numbers = dihedral_numbers
714
- self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
715
- 'gamc', 'gams', 'pn'],
716
- outputs=['ene'])
717
- self.add_prim_attr('dihedral_numbers', self.dihedral_numbers)
718
-
719
- def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, atom_d_shape,
720
- ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
721
- cls_name = self.name
722
- n = uint_crd_f_shape[0]
723
- m = self.dihedral_numbers
724
- validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
725
- validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
726
- validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
727
- validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
728
- validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
729
- validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
730
- validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
731
- validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
732
- validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
733
- validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
734
- validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
735
-
736
- validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
737
- validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
738
- validator.check_int(atom_a_shape[0], m, Rel.EQ, "atom_a_shape", cls_name)
739
- validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
740
- validator.check_int(atom_c_shape[0], m, Rel.EQ, "atom_c_shape", cls_name)
741
- validator.check_int(atom_d_shape[0], m, Rel.EQ, "atom_d_shape", cls_name)
742
- validator.check_int(ipn_shape[0], m, Rel.EQ, "ipn_shape", cls_name)
743
- validator.check_int(pk_shape[0], m, Rel.EQ, "pk_shape", cls_name)
744
- validator.check_int(gamc_shape[0], m, Rel.EQ, "gamc_shape", cls_name)
745
- validator.check_int(gams_shape[0], m, Rel.EQ, "gams_shape", cls_name)
746
- validator.check_int(pn_shape[0], m, Rel.EQ, "pn_shape", cls_name)
747
- return [n,]
748
-
749
- def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
750
- ipn_type, pk_type, gamc_type, gams_type, pn_type):
751
- validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
752
- validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
753
- validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
754
- validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
755
- validator.check_tensor_dtype_valid('atom_c', atom_c_type, [mstype.int32], self.name)
756
- validator.check_tensor_dtype_valid('atom_d', atom_d_type, [mstype.int32], self.name)
757
- validator.check_tensor_dtype_valid('ipn', ipn_type, [mstype.int32], self.name)
758
- validator.check_tensor_dtype_valid('pk', pk_type, [mstype.float32], self.name)
759
- validator.check_tensor_dtype_valid('gamc', gamc_type, [mstype.float32], self.name)
760
- validator.check_tensor_dtype_valid('gams', gams_type, [mstype.float32], self.name)
761
- validator.check_tensor_dtype_valid('pn', pn_type, [mstype.float32], self.name)
762
-
763
- return pn_type
764
-
765
-
766
class DihedralForceWithAtomEnergy(PrimitiveWithInfer):
    """
    Calculate dihedral force and potential energy together.

    The calculation formula is the same as operator DihedralForce() and DihedralEnergy().

    Because there is a large number of inputs and all of them are related,
    there is no way to construct `Examples` using random methods. For details, refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        dihedral_numbers(int32): the number of dihedral terms m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinates
          value of each atom. The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_d** (Tensor) - The 4th atom index of each dihedral.
          4 atoms are connected in the form a-b-c-d. The data type is int32 and the shape is :math:`(m,)`.
        - **ipn** (Tensor) - The period of dihedral angle of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **pk** (Tensor) - The force constant of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **gamc** (Tensor) - k*cos(phi_0) of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **gams** (Tensor) - k*sin(phi_0) of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **pn** (Tensor) - The floating point form of ipn.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **frc_f** (Tensor) - Same as operator DihedralForce().
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **ene** (Tensor) - Same as operator DihedralAtomEnergy().
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, dihedral_numbers):
        """Initialize DihedralForceWithAtomEnergy."""
        validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
        self.dihedral_numbers = dihedral_numbers
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
                                        'gamc', 'gams', 'pn'],
                                outputs=['frc_f', 'ene'])
        self.add_prim_attr('dihedral_numbers', self.dihedral_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, atom_d_shape,
                    ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
        """
        Validate the input shapes and return the shapes of ``frc_f`` and ``ene``.

        ``uint_crd_f`` must be rank 2 with a second dimension of 3, ``scaler_f`` must have shape (3,),
        and every per-dihedral input must be rank 1 with length ``dihedral_numbers``.

        Returns:
            tuple, ``(uint_crd_f_shape, [n])`` where ``n`` is the first dimension of ``uint_crd_f``.
        """
        cls_name = self.name
        m = self.dihedral_numbers
        # Inputs that carry one value per dihedral term; order matches the original check order.
        per_dihedral = (
            ("atom_a", atom_a_shape),
            ("atom_b", atom_b_shape),
            ("atom_c", atom_c_shape),
            ("atom_d", atom_d_shape),
            ("ipn", ipn_shape),
            ("pk", pk_shape),
            ("gamc", gamc_shape),
            ("gams", gams_shape),
            ("pn", pn_shape),
        )

        # Rank checks come first so the element accesses below are only reached for valid ranks
        # (this relies on the validator rejecting a failed check -- its behaviour is defined elsewhere).
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        for label, shape in per_dihedral:
            validator.check_int(len(shape), 1, Rel.EQ, label + "_dim", cls_name)

        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        for label, shape in per_dihedral:
            validator.check_int(shape[0], m, Rel.EQ, label + "_shape", cls_name)

        # Read n only after the rank check: indexing an empty shape up front would fail
        # before the validator gets a chance to report the rank mismatch.
        n = uint_crd_f_shape[0]
        return uint_crd_f_shape, [n,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
                    ipn_type, pk_type, gamc_type, gams_type, pn_type):
        """
        Validate the input dtypes and return the dtypes of ``frc_f`` and ``ene``.

        Returns:
            tuple, ``(pn_type, pn_type)``; ``pn`` has been checked to be float32 at that point.
        """
        expected = (
            ('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
            ('scaler_f', scaler_f_type, mstype.float32),
            ('atom_a', atom_a_type, mstype.int32),
            ('atom_b', atom_b_type, mstype.int32),
            ('atom_c', atom_c_type, mstype.int32),
            ('atom_d', atom_d_type, mstype.int32),
            ('ipn', ipn_type, mstype.int32),
            ('pk', pk_type, mstype.float32),
            ('gamc', gamc_type, mstype.float32),
            ('gams', gams_type, mstype.float32),
            ('pn', pn_type, mstype.float32),
        )
        for arg_name, arg_type, allowed in expected:
            validator.check_tensor_dtype_valid(arg_name, arg_type, [allowed], self.name)

        return pn_type, pn_type
869
-
870
-
871
class AngleForce(PrimitiveWithInfer):
    """
    Calculate the force exerted by angles made of 3 atoms on the
    corresponding atoms. Assume the number of angles is m and the
    number of atoms is n.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::
        dr_{ab} = (x_b-x_a, y_b-y_a, z_b-z_a)
    .. math::
        dr_{cb} = (x_b-x_c, y_b-y_c, z_b-z_c)
    .. math::
        theta = arccos(inner_product(dr_{ab}, dr_{cb})/|dr_{ab}|/|dr_{cb}|)
    .. math::
        F_a = -2*k*(theta-theta_0)/sin(theta)*[cos(theta)/|dr_{ab}|^2*dr_{ab}
            - 1/|dr_{ab}|/|dr_{cb}|*dr_{cb}]
    .. math::
        F_c = -2*k*(theta-theta_0)/sin(theta)*[cos(theta)/|dr_{cb}|^2*dr_{cb}
            - 1/|dr_{cb}|/|dr_{ab}|*dr_{ab}]
    .. math::
        F_b = -F_a - F_c

    Args:
        angle_numbers(int32): the number of angles m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd and the central atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **angle_k** (Tensor) - The force constant for each angle.
          The data type is float32 and the shape is :math:`(m,)`.
        - **angle_theta0** (Tensor) - The equilibrium position value for each angle.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **frc_f** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, angle_numbers):
        """Initialize AngleForce."""
        validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
        self.angle_numbers = angle_numbers
        input_names = ['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k', 'angle_theta0']
        self.init_prim_io_names(inputs=input_names, outputs=['frc_f'])
        self.add_prim_attr('angle_numbers', self.angle_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, angle_k_shape,
                    angle_theta0_shape):
        """
        Check the rank and size of every input shape; the output shape equals ``uint_crd_f_shape``.

        ``uint_crd_f`` must be rank 2 with 3 columns, ``scaler_f`` must have shape (3,), and the
        per-angle inputs must be rank 1 with length ``angle_numbers``.
        """
        prim_name = self.name
        angle_count = self.angle_numbers
        # Per-angle inputs, listed in the order in which they are validated.
        per_angle = (
            ("atom_a", atom_a_shape),
            ("atom_b", atom_b_shape),
            ("atom_c", atom_c_shape),
            ("angle_k", angle_k_shape),
            ("angle_theta0", angle_theta0_shape),
        )

        # Pass 1: ranks.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", prim_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", prim_name)
        for label, shape in per_angle:
            validator.check_int(len(shape), 1, Rel.EQ, label + "_dim", prim_name)

        # Pass 2: dimension sizes.
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", prim_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", prim_name)
        for label, shape in per_angle:
            validator.check_int(shape[0], angle_count, Rel.EQ, label + "_shape", prim_name)
        return uint_crd_f_shape

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
                    angle_theta0_type):
        """Check every input dtype; the output dtype is that of ``angle_k`` (validated as float32)."""
        dtype_table = (
            ('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
            ('scaler_f', scaler_f_type, mstype.float32),
            ('atom_a', atom_a_type, mstype.int32),
            ('atom_b', atom_b_type, mstype.int32),
            ('atom_c', atom_c_type, mstype.int32),
            ('angle_k', angle_k_type, mstype.float32),
            ('angle_theta0', angle_theta0_type, mstype.float32),
        )
        for arg_name, arg_type, allowed in dtype_table:
            validator.check_tensor_dtype_valid(arg_name, arg_type, [allowed], self.name)
        return angle_k_type
965
-
966
-
967
class AngleEnergy(PrimitiveWithInfer):
    """
    Calculate the energy caused by 3-atoms angle term. Assume the number of angles is m and the
    number of atoms is n.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::
        dr_{ab} = (x_b-x_a, y_b-y_a, z_b-z_a)
    .. math::
        dr_{cb} = (x_b-x_c, y_b-y_c, z_b-z_c)
    .. math::
        theta = arccos(inner_product(dr_{ab}, dr_{cb})/|dr_{ab}|/|dr_{cb}|)
    .. math::
        E = k*(theta - theta_0)^2

    Args:
        angle_numbers(int32): the number of angles m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd and the central atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **angle_k** (Tensor) - The force constant for each angle.
          The data type is float32 and the shape is :math:`(m,)`.
        - **angle_theta0** (Tensor) - The equilibrium position value for each angle.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **ene** (Tensor) - The potential energy for each angle term.
          The data type is float32 and the shape is :math:`(m,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, angle_numbers):
        """Initialize AngleEnergy."""
        validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
        self.angle_numbers = angle_numbers
        input_names = ['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k', 'angle_theta0']
        self.init_prim_io_names(inputs=input_names, outputs=['ene'])
        self.add_prim_attr('angle_numbers', self.angle_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, angle_k_shape,
                    angle_theta0_shape):
        """
        Check the rank and size of every input shape; the output shape is ``[angle_numbers]``.

        ``uint_crd_f`` must be rank 2 with 3 columns, ``scaler_f`` must have shape (3,), and the
        per-angle inputs must be rank 1 with length ``angle_numbers``.
        """
        prim_name = self.name
        angle_count = self.angle_numbers
        # Per-angle inputs, listed in the order in which they are validated.
        per_angle = (
            ("atom_a", atom_a_shape),
            ("atom_b", atom_b_shape),
            ("atom_c", atom_c_shape),
            ("angle_k", angle_k_shape),
            ("angle_theta0", angle_theta0_shape),
        )

        # Pass 1: ranks.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", prim_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", prim_name)
        for label, shape in per_angle:
            validator.check_int(len(shape), 1, Rel.EQ, label + "_dim", prim_name)

        # Pass 2: dimension sizes.
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", prim_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", prim_name)
        for label, shape in per_angle:
            validator.check_int(shape[0], angle_count, Rel.EQ, label + "_shape", prim_name)
        return [angle_count,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
                    angle_theta0_type):
        """Check every input dtype; the output dtype is that of ``angle_k`` (validated as float32)."""
        dtype_table = (
            ('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
            ('scaler_f', scaler_f_type, mstype.float32),
            ('atom_a', atom_a_type, mstype.int32),
            ('atom_b', atom_b_type, mstype.int32),
            ('atom_c', atom_c_type, mstype.int32),
            ('angle_k', angle_k_type, mstype.float32),
            ('angle_theta0', angle_theta0_type, mstype.float32),
        )
        for arg_name, arg_type, allowed in dtype_table:
            validator.check_tensor_dtype_valid(arg_name, arg_type, [allowed], self.name)
        return angle_k_type
1054
-
1055
-
1056
class AngleAtomEnergy(PrimitiveWithInfer):
    """
    Add the potential energy caused by angle terms to the total potential
    energy of each atom. Assume the number of angles is m and the
    number of atoms is n.

    The calculation formula is the same as operator AngleEnergy().

    Because there is a large number of inputs and all of them are related,
    there is no way to construct `Examples` using random methods. For details, refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        angle_numbers(int32): the number of angles m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd and the central atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **angle_k** (Tensor) - The force constant for each angle.
          The data type is float32 and the shape is :math:`(m,)`.
        - **angle_theta0** (Tensor) - The equilibrium position value for each angle.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **ene** (Tensor) - The accumulated potential energy for each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, angle_numbers):
        """Initialize AngleAtomEnergy."""
        validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
        self.angle_numbers = angle_numbers
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
                                        'angle_theta0'],
                                outputs=['ene'])
        self.add_prim_attr('angle_numbers', self.angle_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, angle_k_shape,
                    angle_theta0_shape):
        """
        Validate the input shapes and return the shape of ``ene``.

        ``uint_crd_f`` must be rank 2 with a second dimension of 3, ``scaler_f`` must have shape (3,),
        and every per-angle input must be rank 1 with length ``angle_numbers``.

        Returns:
            list, ``[n]`` where ``n`` is the first dimension of ``uint_crd_f``.
        """
        cls_name = self.name
        m = self.angle_numbers
        # Inputs that carry one value per angle; order matches the original check order.
        per_angle = (
            ("atom_a", atom_a_shape),
            ("atom_b", atom_b_shape),
            ("atom_c", atom_c_shape),
            ("angle_k", angle_k_shape),
            ("angle_theta0", angle_theta0_shape),
        )

        # Rank checks come first so the element accesses below are only reached for valid ranks
        # (this relies on the validator rejecting a failed check -- its behaviour is defined elsewhere).
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        for label, shape in per_angle:
            validator.check_int(len(shape), 1, Rel.EQ, label + "_dim", cls_name)

        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        for label, shape in per_angle:
            validator.check_int(shape[0], m, Rel.EQ, label + "_shape", cls_name)

        # Read n only after the rank check: indexing an empty shape up front would fail
        # before the validator gets a chance to report the rank mismatch.
        n = uint_crd_f_shape[0]
        return [n,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
                    angle_theta0_type):
        """
        Validate the input dtypes and return the dtype of ``ene``.

        Returns:
            ``angle_k_type``, which has been checked to be float32 at that point.
        """
        expected = (
            ('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
            ('scaler_f', scaler_f_type, mstype.float32),
            ('atom_a', atom_a_type, mstype.int32),
            ('atom_b', atom_b_type, mstype.int32),
            ('atom_c', atom_c_type, mstype.int32),
            ('angle_k', angle_k_type, mstype.float32),
            ('angle_theta0', angle_theta0_type, mstype.float32),
        )
        for arg_name, arg_type, allowed in expected:
            validator.check_tensor_dtype_valid(arg_name, arg_type, [allowed], self.name)
        return angle_k_type
1138
-
1139
-
1140
class AngleForceWithAtomEnergy(PrimitiveWithInfer):
    """
    Calculate angle force and potential energy together. Assume the number of angles is m and the
    number of atoms is n.

    The calculation formula is the same as operator AngleForce() and AngleEnergy().

    Because there is a large number of inputs and all of them are related,
    there is no way to construct `Examples` using random methods. For details, refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        angle_numbers(int32): the number of angles m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd and the central atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **angle_k** (Tensor) - The force constant for each angle.
          The data type is float32 and the shape is :math:`(m,)`.
        - **angle_theta0** (Tensor) - The equilibrium position value for each angle.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **frc_f** (Tensor) - same as operator AngleForce().
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **ene** (Tensor) - same as operator AngleAtomEnergy().
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, angle_numbers):
        """Initialize AngleForceWithAtomEnergy."""
        validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
        self.angle_numbers = angle_numbers
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
                                        'angle_theta0'],
                                outputs=['frc_f', 'ene'])
        self.add_prim_attr('angle_numbers', self.angle_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, angle_k_shape,
                    angle_theta0_shape):
        """
        Validate the input shapes and return the shapes of ``frc_f`` and ``ene``.

        ``uint_crd_f`` must be rank 2 with a second dimension of 3, ``scaler_f`` must have shape (3,),
        and every per-angle input must be rank 1 with length ``angle_numbers``.

        Returns:
            tuple, ``(uint_crd_f_shape, [n])`` where ``n`` is the first dimension of ``uint_crd_f``.
        """
        cls_name = self.name
        m = self.angle_numbers
        # Inputs that carry one value per angle; order matches the original check order.
        per_angle = (
            ("atom_a", atom_a_shape),
            ("atom_b", atom_b_shape),
            ("atom_c", atom_c_shape),
            ("angle_k", angle_k_shape),
            ("angle_theta0", angle_theta0_shape),
        )

        # Rank checks come first so the element accesses below are only reached for valid ranks
        # (this relies on the validator rejecting a failed check -- its behaviour is defined elsewhere).
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        for label, shape in per_angle:
            validator.check_int(len(shape), 1, Rel.EQ, label + "_dim", cls_name)

        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        for label, shape in per_angle:
            validator.check_int(shape[0], m, Rel.EQ, label + "_shape", cls_name)

        # Read n only after the rank check: indexing an empty shape up front would fail
        # before the validator gets a chance to report the rank mismatch.
        n = uint_crd_f_shape[0]
        return uint_crd_f_shape, [n,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
                    angle_theta0_type):
        """
        Validate the input dtypes and return the dtypes of ``frc_f`` and ``ene``.

        Returns:
            tuple, ``(angle_k_type, angle_k_type)``; ``angle_k`` has been checked to be float32 at that point.
        """
        expected = (
            ('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
            ('scaler_f', scaler_f_type, mstype.float32),
            ('atom_a', atom_a_type, mstype.int32),
            ('atom_b', atom_b_type, mstype.int32),
            ('atom_c', atom_c_type, mstype.int32),
            ('angle_k', angle_k_type, mstype.float32),
            ('angle_theta0', angle_theta0_type, mstype.float32),
        )
        for arg_name, arg_type, allowed in expected:
            validator.check_tensor_dtype_valid(arg_name, arg_type, [allowed], self.name)
        return angle_k_type, angle_k_type
1223
-
1224
-
1225
- class Dihedral14LJForce(PrimitiveWithInfer):
1226
- """
1227
- Calculate the Lennard-Jones part of 1,4 dihedral force correction
1228
- for each necessary dihedral terms on the corresponding atoms.
1229
-
1230
- Assume the number of necessary dihedral 1,4 terms is m, the number of atoms is n,
1231
- and the number of Lennard-Jones types for all atoms is P, which means
1232
- there will be q = P*(P+1)/2 types of possible Lennard-Jones interactions
1233
- for all kinds of atom pairs.
1234
-
1235
- Because there is a large amount of inputs and each of them are related,
1236
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
1237
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
1238
-
1239
- .. math::
1240
- dr = (x_a-x_b, y_a-y_b, z_a-z_b)
1241
- .. math::
1242
- F = k*(-12*A/|dr|^{14} + 6*B/|dr|^{8})*dr
1243
-
1244
- Args:
1245
- nb14_numbers (int32): the number of necessary dihedral 1,4 terms m.
1246
- atom_numbers (int32): the number of atoms n.
1247
-
1248
- Inputs:
1249
- - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
1250
- The data type is uint32 and the shape is :math:`(n, 3)`.
1251
- - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
1252
- The data type is int32 and the shape is :math:`(n,)`.
1253
- - **charge** (Tensor) - The charge of each atom.
1254
- The data type is float32 and the shape is :math:`(n,)`.
1255
- - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
1256
- The data type is float32 and the shape is :math:`(3,)`.
1257
- - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
1258
- The data type is int32 and the shape is :math:`(m,)`.
1259
- - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
1260
- The data type is int32 and the shape is :math:`(m,)`.
1261
- - **lj_scale_factor** (Tensor) - The scale factor for the
1262
- Lennard-Jones part of force correction of each dihedral 1,4 term.
1263
- The data type is float32 and the shape is :math:`(m,)`.
1264
- - **LJ_type_A** (Tensor) - The A parameter in Lennard-Jones scheme of each atom pair type.
1265
- q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
1266
- - **LJ_type_B** (Tensor) - The B parameter in Lennard-Jones scheme of each atom pair type.
1267
- q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
1268
-
1269
- Outputs:
1270
- - **frc_f** (Tensor) - The force felt by each atom.
1271
- The data type is float32 and the shape is :math:`(n, 3)`.
1272
-
1273
- Supported Platforms:
1274
- ``GPU``
1275
- """
1276
-
1277
- @prim_attr_register
1278
- def __init__(self, nb14_numbers, atom_numbers):
1279
- """Initialize Dihedral14LJForce."""
1280
- validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
1281
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
1282
- self.dihedral_14_numbers = nb14_numbers
1283
- self.atom_numbers = atom_numbers
1284
- self.init_prim_io_names(
1285
- inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'lj_scale_factor',
1286
- 'LJ_type_A', 'LJ_type_B'],
1287
- outputs=['frc_f'])
1288
- self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
1289
- self.add_prim_attr('atom_numbers', self.atom_numbers)
1290
-
1291
def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                lj_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
    """Validate the input shapes and return the output shape.

    The rank of every input is checked first, then the leading sizes against
    `atom_numbers` (n), `dihedral_14_numbers` (m) and the length of LJ_type_A (q).

    Returns:
        The shape of `uint_crd_f`, which the checks pin to ``[n, 3]``.
    """
    cls_name = self.name
    n = self.atom_numbers
    m = self.dihedral_14_numbers
    validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
    validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
    validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
    validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
    validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
    validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
    validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
    # Check the rank of LJ_type_A before indexing it, so that a scalar input is
    # reported by the validator instead of surfacing as a bare IndexError.
    validator.check_int(len(lj_type_a_shape), 1, Rel.EQ, "LJ_type_A_dim", cls_name)
    validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

    # q is defined by LJ_type_A; LJ_type_B must have the same length.
    q = lj_type_a_shape[0]
    validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f[0]", cls_name)
    validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
    validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype", cls_name)
    validator.check_int(charge_shape[0], n, Rel.EQ, "charge", cls_name)
    validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
    validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B", cls_name)
    validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
    validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
    validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
    return uint_crd_f_shape
1316
-
1317
def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                lj_scale_factor_type, lj_type_a_type, lj_type_b_type):
    """Validate the dtype of every input and return the output dtype.

    Returns:
        The dtype of `LJ_type_B`, which the checks pin to float32.
    """
    # (argument name, actual dtype, the single dtype accepted), in input order.
    expectations = (
        ('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
        ('LJtype', ljtype_dtype, mstype.int32),
        ('charge', charge_dtype, mstype.float32),
        ('boxlength_f', boxlength_f_type, mstype.float32),
        ('a_14', a_14_type, mstype.int32),
        ('b_14', b_14_type, mstype.int32),
        ('lj_scale_factor', lj_scale_factor_type, mstype.float32),
        ('LJ_type_A', lj_type_a_type, mstype.float32),
        ('LJ_type_B', lj_type_b_type, mstype.float32),
    )
    for arg_name, actual, accepted in expectations:
        validator.check_tensor_dtype_valid(arg_name, actual, [accepted], self.name)
    return lj_type_b_type
1331
-
1332
-
1333
class Dihedral14LJEnergy(PrimitiveWithInfer):
    """
    Calculate the Lennard-Jones part of 1,4 dihedral energy correction for
    each necessary dihedral terms on the corresponding atoms.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods.
    For details, refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::
        dr = (x_a-x_b, y_a-y_b, z_a-z_b)

    .. math::
        E = k*(A/|dr|^{12} - B/|dr|^{6})

    Args:
        nb14_numbers (int): the number of necessary dihedral 1,4 terms m.
        atom_numbers (int): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
          The data type is float32 and the shape is :math:`(3,)`.
        - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **lj_scale_factor** (Tensor) - The scale factor for the
          Lennard-Jones part of force correction of each dihedral 1,4 term.
          The data type is float32 and the shape is :math:`(m,)`.
        - **LJ_type_A** (Tensor) - The A parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
        - **LJ_type_B** (Tensor) - The B parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.

    Outputs:
        - **ene** (Tensor) - The Lennard-Jones potential energy correction.
          The data type is float32 and the shape is :math:`(m,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, nb14_numbers, atom_numbers):
        """Initialize Dihedral14LJEnergy"""
        validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.dihedral_14_numbers = nb14_numbers
        self.atom_numbers = atom_numbers

        self.init_prim_io_names(
            inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'lj_scale_factor',
                    'LJ_type_A', 'LJ_type_B'],
            outputs=['ene'])
        self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)

    def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                    lj_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
        """Validate the input shapes and return the output shape ``[m]``."""
        cls_name = self.name
        n = self.atom_numbers
        m = self.dihedral_14_numbers
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
        validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
        validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
        validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
        # Check the rank of LJ_type_A before indexing it, so that a scalar input is
        # reported by the validator instead of surfacing as a bare IndexError.
        validator.check_int(len(lj_type_a_shape), 1, Rel.EQ, "LJ_type_A_dim", cls_name)
        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

        # q is defined by LJ_type_A; LJ_type_B must have the same length.
        q = lj_type_a_shape[0]
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
        validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype", cls_name)
        validator.check_int(charge_shape[0], n, Rel.EQ, "charge", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B", cls_name)
        validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
        validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
        validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
        return [self.dihedral_14_numbers,]

    def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                    lj_scale_factor_type, lj_type_a_type, lj_type_b_type):
        """Validate the input dtypes and return the dtype of `LJ_type_A` (float32 once validated)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('boxlength_f', boxlength_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)

        return lj_type_a_type
1434
-
1435
-
1436
class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
    """
    Calculate the Lennard-Jones part and the Coulomb part of force correction
    for each necessary dihedral 1,4 terms.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods.
    For details, refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    The calculation formula of the Lennard-Jones part is the same as operator
    Dihedral14LJForce(), and the Coulomb part is as follows:

    .. math::
        dr = (x_a-x_b, y_a-y_b, z_a-z_b)

    .. math::
        F = -k*q_a*q_b/|r|^3*dr

    Args:
        nb14_numbers (int): the number of necessary dihedral 1,4 terms m.
        atom_numbers (int): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
          The data type is float32 and the shape is :math:`(3,)`.
        - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **lj_scale_factor** (Tensor) - The scale factor for the
          Lennard-Jones part of force correction of each dihedral 1,4 term.
          The data type is float32 and the shape is :math:`(m,)`.
        - **cf_scale_factor** (Tensor) - The scale factor for the
          Coulomb part of force correction for each dihedral 1,4 terms.
          The data type is float32 and the shape is :math:`(m,)`.
        - **LJ_type_A** (Tensor) - The A parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
        - **LJ_type_B** (Tensor) - The B parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.

    Outputs:
        - **frc_f** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, nb14_numbers, atom_numbers):
        """Initialize Dihedral14LJForceWithDirectCF."""
        validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.dihedral_14_numbers = nb14_numbers
        self.atom_numbers = atom_numbers

        self.init_prim_io_names(
            inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'lj_scale_factor',
                    'cf_scale_factor',
                    'LJ_type_A', 'LJ_type_B'],
            outputs=['frc_f'])
        self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)

    def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                    lj_scale_factor_shape, cf_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
        """Validate the input shapes and return the output shape ``[n, 3]``."""
        cls_name = self.name
        n = self.atom_numbers
        m = self.dihedral_14_numbers
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
        validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
        validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
        validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
        validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
        # Check the rank of LJ_type_A before indexing it, so that a scalar input is
        # reported by the validator instead of surfacing as a bare IndexError.
        validator.check_int(len(lj_type_a_shape), 1, Rel.EQ, "LJ_type_A_dim", cls_name)
        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

        # q is defined by LJ_type_A; LJ_type_B must have the same length.
        q = lj_type_a_shape[0]
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
        validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
        validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
        validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
        validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
        validator.check_int(cf_scale_factor_shape[0], m, Rel.EQ, "cf_scale_factor_shape", cls_name)
        return [self.atom_numbers, 3]

    def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                    lj_scale_factor_type, cf_scale_factor_type, lj_type_a_type, lj_type_b_type):
        """Validate the input dtypes and return the dtype of `LJ_type_A` (float32 once validated)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('boxlength_f', boxlength_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('cf_scale_factor', cf_scale_factor_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)

        return lj_type_a_type
1547
-
1548
-
1549
class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
    """
    Calculate the Lennard-Jones and Coulomb energy correction and force correction
    for each necessary dihedral 1,4 terms together and add them to the total force
    and potential energy for each atom.

    The calculation formula of force correction is the same as operator
    :class:`Dihedral14LJForceWithDirectCF`, and the energy correction part is the same
    as operator :class:`Dihedral14LJEnergy` and :class:`Dihedral14CFEnergy`.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods.
    For details, refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        nb14_numbers (int): the number of necessary dihedral 1,4 terms m.
        atom_numbers (int): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
          The data type is float32 and the shape is :math:`(3,)`.
        - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **lj_scale_factor** (Tensor) - The scale factor for the
          Lennard-Jones part of force correction of each dihedral 1,4 term.
          The data type is float32 and the shape is :math:`(m,)`.
        - **cf_scale_factor** (Tensor) - The scale factor for the
          Coulomb part of force correction for each dihedral 1,4 terms.
          The data type is float32 and the shape is :math:`(m,)`.
        - **LJ_type_A** (Tensor) - The A parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
        - **LJ_type_B** (Tensor) - The B parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.

    Outputs:
        - **frc_f** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **atom_energy** (Tensor) - The accumulated potential energy for each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, nb14_numbers, atom_numbers):
        """Initialize Dihedral14LJCFForceWithAtomEnergy."""
        validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.dihedral_14_numbers = nb14_numbers
        self.atom_numbers = atom_numbers

        self.init_prim_io_names(
            inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'lj_scale_factor',
                    'cf_scale_factor',
                    'LJ_type_A', 'LJ_type_B'],
            outputs=['frc_f', 'atom_energy'])
        self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)

    def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                    lj_scale_factor_shape, cf_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
        """Validate the input shapes and return the shapes of (`frc_f`, `atom_energy`).

        The two outputs reuse the validated shapes of `uint_crd_f` (``[n, 3]``)
        and `charge` (``[n]``) respectively.
        """
        cls_name = self.name
        n = self.atom_numbers
        m = self.dihedral_14_numbers
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
        validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
        validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
        validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
        validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
        # Check the rank of LJ_type_A before indexing it, so that a scalar input is
        # reported by the validator instead of surfacing as a bare IndexError.
        validator.check_int(len(lj_type_a_shape), 1, Rel.EQ, "LJ_type_A_dim", cls_name)
        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

        # q is defined by LJ_type_A; LJ_type_B must have the same length.
        q = lj_type_a_shape[0]
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
        validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
        validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
        validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
        validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
        validator.check_int(cf_scale_factor_shape[0], m, Rel.EQ, "cf_scale_factor_shape", cls_name)
        return uint_crd_f_shape, charge_shape

    def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                    lj_scale_factor_type, cf_scale_factor_type, lj_type_a_type, lj_type_b_type):
        """Validate the input dtypes; both outputs take the dtype of `charge` (float32 once validated)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('boxlength_f', boxlength_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('cf_scale_factor', cf_scale_factor_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)

        return charge_dtype, charge_dtype
1659
-
1660
-
1661
class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
    """
    Add the potential energy caused by Lennard-Jones energy correction for each
    necessary dihedral 1,4 terms to the total potential energy of each atom.

    The calculation formula is the same as operator Dihedral14LJEnergy().

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods.
    For details, refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        nb14_numbers (int): the number of necessary dihedral 1,4 terms m.
        atom_numbers (int): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
          The data type is float32 and the shape is :math:`(3,)`.
        - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **lj_scale_factor** (Tensor) - The scale factor for the
          Lennard-Jones part of force correction of each dihedral 1,4 term.
          The data type is float32 and the shape is :math:`(m,)`.
        - **LJ_type_A** (Tensor) - The A parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
        - **LJ_type_B** (Tensor) - The B parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.

    Outputs:
        - **ene** (Tensor) - The accumulated potential energy of each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, nb14_numbers, atom_numbers):
        """Initialize Dihedral14LJAtomEnergy."""
        validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.dihedral_14_numbers = nb14_numbers
        self.atom_numbers = atom_numbers

        # This operator takes no cf_scale_factor input: nine inputs, one output.
        self.init_prim_io_names(
            inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'lj_scale_factor',
                    'LJ_type_A', 'LJ_type_B'],
            outputs=['ene'])
        self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)

    def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                    lj_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
        """Validate the input shapes and return the output shape.

        The output reuses the validated shape of `LJtype`, i.e. ``[n]``.
        """
        cls_name = self.name
        n = self.atom_numbers
        m = self.dihedral_14_numbers
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
        validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
        validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
        validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
        # Check the rank of LJ_type_A before indexing it, so that a scalar input is
        # reported by the validator instead of surfacing as a bare IndexError.
        validator.check_int(len(lj_type_a_shape), 1, Rel.EQ, "LJ_type_A_dim", cls_name)
        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

        # q is defined by LJ_type_A; LJ_type_B must have the same length.
        q = lj_type_a_shape[0]
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
        validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
        validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
        validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
        validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
        return ljtype_shape

    def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                    lj_scale_factor_type, lj_type_a_type, lj_type_b_type):
        """Validate the input dtypes and return the dtype of `LJ_type_A` (float32 once validated)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('boxlength_f', boxlength_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32],
                                           self.name)
        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)

        return lj_type_a_type
1763
-
1764
-
1765
- class Dihedral14CFEnergy(PrimitiveWithInfer):
1766
- """
1767
- Calculate the Coulumb part of 1,4 dihedral energy correction for
1768
- each necessary dihedral terms on the corresponding atoms.
1769
-
1770
- Because there is a large amount of inputs and each of them are related,
1771
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
1772
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
1773
-
1774
- .. math::
1775
-
1776
- dr = (x_a-x_b, y_a-y_b, z_a-z_b)
1777
-
1778
- .. math::
1779
- E = k*q_a*q_b/|dr|
1780
-
1781
- Args:
1782
- nb14_numbers (int32): the number of necessary dihedral 1,4 terms m.
1783
- atom_numbers (int32): the number of atoms n.
1784
-
1785
- Inputs:
1786
- - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
1787
- The data type is uint32 and the shape is :math:`(n, 3)`.
1788
- - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
1789
- The data type is int32 and the shape is :math:`(n,)`.
1790
- - **charge** (Tensor) - The charge of each atom.
1791
- The data type is float32 and the shape is :math:`(n,)`.
1792
- - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
1793
- The data type is float32 and the shape is :math:`(3,)`.
1794
- - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
1795
- The data type is int32 and the shape is :math:`(m,)`.
1796
- - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
1797
- The data type is int32 and the shape is :math:`(m,)`.
1798
- - **cf_scale_factor** (Tensor) - The scale factor for the
1799
- Coulomb part of force correction for each dihedral 1,4 terms.
1800
- The data type is float32 and the shape is :math:`(m,)`.
1801
-
1802
- Outputs:
1803
- - **ene** (Tensor) - The accumulated potential energy of each atom.
1804
- The data type is float32 and the shape is :math:`(m,)`.
1805
-
1806
- Supported Platforms:
1807
- ``GPU``
1808
- """
1809
-
1810
- @prim_attr_register
1811
- def __init__(self, nb14_numbers, atom_numbers):
1812
- """Initialize Dihedral14CFEnergy."""
1813
- validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
1814
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
1815
- self.dihedral_14_numbers = nb14_numbers
1816
- self.atom_numbers = atom_numbers
1817
-
1818
- self.init_prim_io_names(
1819
- inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'cj_scale_factor'],
1820
- outputs=['ene'])
1821
- self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
1822
- self.add_prim_attr('atom_numbers', self.atom_numbers)
1823
-
1824
def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                cf_scale_factor_shape):
    """Validate every input shape and return the output shape.

    Ranks are checked first, then the leading dimensions against
    ``atom_numbers`` (n), the literal 3, and ``dihedral_14_numbers`` (m).

    Returns:
        A one-element list holding ``dihedral_14_numbers``.
    """
    cls_name = self.name
    n = self.atom_numbers

    # Rank checks: only uint_crd_f is 2-D, every other input is 1-D.
    rank_specs = (
        (uint_crd_f_shape, 2, "uint_crd_f_dim"),
        (ljtype_shape, 1, "LJtype_dim"),
        (charge_shape, 1, "charge_dim"),
        (boxlength_f_shape, 1, "boxlength_f_dim"),
        (a_14_shape, 1, "a_14_dim"),
        (b_14_shape, 1, "b_14_dim"),
        (cf_scale_factor_shape, 1, "cf_scale_factor_dim"),
    )
    for shape, rank, label in rank_specs:
        validator.check_int(len(shape), rank, Rel.EQ, label, cls_name)

    # Per-atom inputs and the fixed-size box vector.
    atom_specs = (
        (uint_crd_f_shape, 0, n, "uint_crd_f_shape[0]"),
        (uint_crd_f_shape, 1, 3, "uint_crd_f_shape[1]"),
        (ljtype_shape, 0, n, "LJtype_shape"),
        (charge_shape, 0, n, "charge_shape"),
        (boxlength_f_shape, 0, 3, "boxlength_f_shape"),
    )
    for shape, axis, expected, label in atom_specs:
        validator.check_int(shape[axis], expected, Rel.EQ, label, cls_name)

    # Per-term inputs must all have length m.
    m = self.dihedral_14_numbers
    term_specs = (
        (a_14_shape, "a_14_shape"),
        (b_14_shape, "b_14_shape"),
        (cf_scale_factor_shape, "cf_scale_factor_shape"),
    )
    for shape, label in term_specs:
        validator.check_int(shape[0], m, Rel.EQ, label, cls_name)
    return [m]
1846
-
1847
def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                cf_scale_factor_type):
    """Validate the dtype of every input and return the output dtype.

    uint_crd_f must be uint32; LJtype, a_14 and b_14 must be int32;
    charge, boxlength_f and cf_scale_factor must be float32.

    Returns:
        ``charge_dtype``, i.e. the output shares the dtype of ``charge``.
    """
    validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
    validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
    validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
    validator.check_tensor_dtype_valid('boxlength_f', boxlength_f_type, [mstype.float32], self.name)
    validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
    validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
    # The label used to be 'lj_scale_factor', which made dtype errors point at
    # an input this operator does not have.
    validator.check_tensor_dtype_valid('cf_scale_factor', cf_scale_factor_type, [mstype.float32],
                                       self.name)

    return charge_dtype
1859
-
1860
-
1861
class Dihedral14CFAtomEnergy(PrimitiveWithInfer):
    """
    Add the potential energy caused by the Coulomb energy correction of each
    necessary dihedral 1,4 term to the total potential energy of each atom.

    The calculation formula is the same as operator :class:`Dihedral14CFEnergy`.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        nb14_numbers (int32): the number of necessary dihedral 1,4 terms m.
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJtype** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
          The data type is float32 and the shape is :math:`(3,)`.
        - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **cf_scale_factor** (Tensor) - The scale factor for the
          Coulomb part of force correction for each dihedral 1,4 terms.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **ene** (Tensor) - The accumulated potential energy of each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, nb14_numbers, atom_numbers):
        """Initialize Dihedral14CFAtomEnergy."""
        prim_name = self.name
        for label, value in (('nb14_numbers', nb14_numbers), ('atom_numbers', atom_numbers)):
            validator.check_value_type(label, value, int, prim_name)

        self.dihedral_14_numbers = nb14_numbers
        self.atom_numbers = atom_numbers

        input_names = ['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'cf_scale_factor']
        self.init_prim_io_names(inputs=input_names, outputs=['ene'])
        self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)

    def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                    cf_scale_factor_shape):
        """Validate all input shapes; the output takes the shape of the LJtype input."""
        cls_name = self.name
        n = self.atom_numbers

        # Rank checks: only uint_crd_f is 2-D, every other input is 1-D.
        rank_specs = (
            (uint_crd_f_shape, 2, "uint_crd_f_dim"),
            (ljtype_shape, 1, "LJtype_dim"),
            (charge_shape, 1, "charge_dim"),
            (boxlength_f_shape, 1, "boxlength_f_dim"),
            (a_14_shape, 1, "a_14_dim"),
            (b_14_shape, 1, "b_14_dim"),
            (cf_scale_factor_shape, 1, "cf_scale_factor_dim"),
        )
        for shape, rank, label in rank_specs:
            validator.check_int(len(shape), rank, Rel.EQ, label, cls_name)

        # Per-atom inputs and the fixed-size box vector.
        atom_specs = (
            (uint_crd_f_shape, 0, n, "uint_crd_f_shape[0]"),
            (uint_crd_f_shape, 1, 3, "uint_crd_f_shape[1]"),
            (ljtype_shape, 0, n, "LJtype_shape"),
            (charge_shape, 0, n, "charge_shape"),
            (boxlength_f_shape, 0, 3, "boxlength_f_shape"),
        )
        for shape, axis, expected, label in atom_specs:
            validator.check_int(shape[axis], expected, Rel.EQ, label, cls_name)

        # Per-term inputs must all have length m.
        m = self.dihedral_14_numbers
        term_specs = (
            (a_14_shape, "a_14_shape"),
            (b_14_shape, "b_14_shape"),
            (cf_scale_factor_shape, "cf_scale_factor_shape"),
        )
        for shape, label in term_specs:
            validator.check_int(shape[0], m, Rel.EQ, label, cls_name)
        return ljtype_shape

    def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                    cf_scale_factor_type):
        """Validate all input dtypes; the output takes the dtype of the charge input."""
        dtype_specs = (
            ('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
            ('LJtype', ljtype_dtype, mstype.int32),
            ('charge', charge_dtype, mstype.float32),
            ('boxlength_f', boxlength_f_type, mstype.float32),
            ('a_14', a_14_type, mstype.int32),
            ('b_14', b_14_type, mstype.int32),
            ('cf_scale_factor', cf_scale_factor_type, mstype.float32),
        )
        for label, dtype, allowed in dtype_specs:
            validator.check_tensor_dtype_valid(label, dtype, [allowed], self.name)

        return charge_dtype
1951
-
1952
-
1953
class PMEReciprocalForce(PrimitiveWithInfer):
    """
    Calculate the reciprocal part of long-range Coulomb force using
    PME(Particle Meshed Ewald) method. Assume the number of atoms is n.

    The detailed calculation formula of PME(Particle Meshed Ewald) method
    can be found in this paper: A Smooth Particle Mesh Ewald Method. DOI:
    10.1063/1.470117.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        atom_numbers(int32): the number of atoms, n.
        beta(float32): the PME beta parameter, determined by the
          non-bond cutoff value and simulation precision tolerance.
        fftx(int32): the number of points for Fourier transform in dimension X.
        ffty(int32): the number of points for Fourier transform in dimension Y.
        fftz(int32): the number of points for Fourier transform in dimension Z.
        box_length_0(float32): the value of boxlength idx 0
        box_length_1(float32): the value of boxlength idx 1
        box_length_2(float32): the value of boxlength idx 2

    Inputs:
        - **uint_crd** (Tensor) - The unsigned int coordinates value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`
        - **charge** (Tensor) - The charge carried by each atom.
          The data type is float32 and the shape is :math:`(n,)`

    Outputs:
        - **force** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, beta, fftx, ffty, fftz, box_length_0, box_length_1, box_length_2):
        """Initialize PMEReciprocalForce."""
        prim_name = self.name
        typed_args = (
            ('atom_numbers', atom_numbers, int),
            ('beta', beta, float),
            ('fftx', fftx, int),
            ('ffty', ffty, int),
            ('fftz', fftz, int),
            ('box_length_0', box_length_0, float),
            ('box_length_1', box_length_1, float),
            ('box_length_2', box_length_2, float),
        )
        for label, value, expected_type in typed_args:
            validator.check_value_type(label, value, expected_type, prim_name)

        self.atom_numbers = atom_numbers
        self.beta = beta
        self.fftx = fftx
        self.ffty = ffty
        self.fftz = fftz
        self.box_length_0 = box_length_0
        self.box_length_1 = box_length_1
        self.box_length_2 = box_length_2

        # NOTE(review): three input names are registered ('boxlength' first) but
        # infer_shape/infer_dtype below accept only uint_crd and charge; the box
        # lengths arrive as attributes. Kept as-is -- confirm 'boxlength' is intended.
        self.init_prim_io_names(inputs=['boxlength', 'uint_crd', 'charge'], outputs=['force'])
        for attr in ('atom_numbers', 'beta', 'fftx', 'ffty', 'fftz',
                     'box_length_0', 'box_length_1', 'box_length_2'):
            self.add_prim_attr(attr, getattr(self, attr))

    def infer_shape(self, uint_crd_shape, charge_shape):
        """Validate the input shapes; the output takes the shape of uint_crd."""
        cls_name = self.name
        n = self.atom_numbers

        for shape, rank, label in ((uint_crd_shape, 2, "uint_crd_dim"), (charge_shape, 1, "charge_dim")):
            validator.check_int(len(shape), rank, Rel.EQ, label, cls_name)

        dim_specs = (
            (uint_crd_shape, 0, n, "uint_crd_shape[0]"),
            (uint_crd_shape, 1, 3, "uint_crd_shape[1]"),
            (charge_shape, 0, n, "charge_shape"),
        )
        for shape, axis, expected, label in dim_specs:
            validator.check_int(shape[axis], expected, Rel.EQ, label, cls_name)
        return uint_crd_shape

    def infer_dtype(self, uint_crd_type, charge_type):
        """Validate the input dtypes; the output takes the dtype of charge."""
        for label, dtype, allowed in (('uint_crd', uint_crd_type, mstype.uint32),
                                      ('charge', charge_type, mstype.float32)):
            validator.check_tensor_dtype_valid(label, dtype, [allowed], self.name)
        return charge_type
2037
-
2038
-
2039
class PMEExcludedForce(PrimitiveWithInfer):
    """
    Calculate the excluded part of long-range Coulomb force using
    PME(Particle Meshed Ewald) method. Assume the number of atoms is
    n, and the length of excluded list is E.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        atom_numbers(int32): the number of atoms, n.
        excluded_numbers(int32): the length of excluded list, E.
        beta(float32): the PME beta parameter, determined by the
          non-bond cutoff value and simulation precision tolerance.

    Inputs:
        - **uint_crd** (Tensor) - The unsigned int coordinates value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`
        - **scaler** (Tensor) - The scale factor between real space
          coordinates and its unsigned int value. The data type is float32 and the shape is :math:`(3,)`
        - **charge** (Tensor) - The charge carried by each atom.
          The data type is float32 and the shape is :math:`(n,)`
        - **excluded_list_start** (Tensor) - The start excluded index
          in excluded list for each atom. The data type is int32 and the shape is :math:`(n,)`
        - **excluded_list** (Tensor) - The contiguous join of excluded
          list of each atom. E is the number of excluded atoms. The data type is int32 and the shape is :math:`(E,)`
        - **excluded_atom_numbers** (Tensor) - The number of atom excluded
          in excluded list for each atom. The data type is int32 and the shape is :math:`(n,)`

    Outputs:
        - **force** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, excluded_numbers, beta):
        """Initialize PMEExcludedForce."""
        prim_name = self.name
        typed_args = (
            ('atom_numbers', atom_numbers, int),
            ('excluded_numbers', excluded_numbers, int),
            ('beta', beta, float),
        )
        for label, value, expected_type in typed_args:
            validator.check_value_type(label, value, expected_type, prim_name)

        self.atom_numbers = atom_numbers
        self.excluded_numbers = excluded_numbers
        self.beta = beta

        # NOTE(review): 'sacler' is presumably a misspelling of 'scaler'; it is a
        # registered runtime name, so it is reproduced unchanged here.
        input_names = ['uint_crd', 'sacler', 'charge', 'excluded_list_start', 'excluded_list',
                       'excluded_atom_numbers']
        self.init_prim_io_names(inputs=input_names, outputs=['force'])
        for attr in ('atom_numbers', 'excluded_numbers', 'beta'):
            self.add_prim_attr(attr, getattr(self, attr))

    def infer_shape(self, uint_crd_shape, sacler_shape, charge_shape, excluded_list_start_shape, excluded_list_shape,
                    excluded_atom_numbers_shape):
        """Validate the input shapes; the output takes the shape of uint_crd.

        Note that ``excluded_list_shape`` is accepted but never inspected.
        """
        cls_name = self.name
        n = self.atom_numbers

        rank_specs = (
            (uint_crd_shape, 2, "uint_crd_dim"),
            (sacler_shape, 1, "sacler_dim"),
            (charge_shape, 1, "charge_dim"),
            (excluded_list_start_shape, 1, "excluded_list_start_dim"),
            (excluded_atom_numbers_shape, 1, "excluded_atom_numbers_dim"),
        )
        for shape, rank, label in rank_specs:
            validator.check_int(len(shape), rank, Rel.EQ, label, cls_name)

        dim_specs = (
            (uint_crd_shape, 0, n, "uint_crd_shape[0]"),
            (uint_crd_shape, 1, 3, "uint_crd_shape[1]"),
            (sacler_shape, 0, 3, "sacler_shape"),
            (charge_shape, 0, n, "charge_shape"),
            (excluded_list_start_shape, 0, n, "excluded_list_start_shape"),
            (excluded_atom_numbers_shape, 0, n, "excluded_atom_numbers_shape"),
        )
        for shape, axis, expected, label in dim_specs:
            validator.check_int(shape[axis], expected, Rel.EQ, label, cls_name)
        return uint_crd_shape

    def infer_dtype(self, uint_crd_type, sacler_type, charge_type, excluded_list_start_type, excluded_list_type,
                    excluded_atom_numbers_type):
        """Validate the input dtypes; the output takes the dtype of charge."""
        # The scale factor is checked before the coordinates, as in the original order.
        dtype_specs = (
            ('sacler', sacler_type, mstype.float32),
            ('uint_crd', uint_crd_type, mstype.uint32),
            ('charge', charge_type, mstype.float32),
            ('excluded_list_start', excluded_list_start_type, mstype.int32),
            ('excluded_list', excluded_list_type, mstype.int32),
            ('excluded_atom_numbers', excluded_atom_numbers_type, mstype.int32),
        )
        for label, dtype, allowed in dtype_specs:
            validator.check_tensor_dtype_valid(label, dtype, [allowed], self.name)
        return charge_type
2123
-
2124
-
2125
class PMEEnergy(PrimitiveWithInfer):
    """
    Calculate the Coulomb energy of the system using PME method.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::

        E = sum_{ij} q_iq_j/r_{ij}

    Args:
        atom_numbers(int32): the number of atoms, n.
        excluded_numbers(int32): the length of excluded list, E.
        beta(float32): the PME beta parameter, determined by the
          non-bond cutoff value and simulation precision tolerance.
        fftx(int32): the number of points for Fourier transform in dimension X.
        ffty(int32): the number of points for Fourier transform in dimension Y.
        fftz(int32): the number of points for Fourier transform in dimension Z.
        box_length_0(float32): the value of boxlength idx 0
        box_length_1(float32): the value of boxlength idx 1
        box_length_2(float32): the value of boxlength idx 2

    Inputs:
        - **uint_crd** (Tensor) - The unsigned int coordinates value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`
        - **charge** (Tensor) - The charge carried by each atom.
          The data type is float32 and the shape is :math:`(n,)`
        - **nl_numbers** (Tensor) - A per-atom value (presumably the number of
          neighbors of each atom; TODO confirm).
          The data type is int32 and the shape is :math:`(n,)`
        - **nl_serial** (Tensor) - The neighbor list of each atom, the max number is 800.
          The data type is int32 and the shape is :math:`(n, 800)`
        - **scaler** (Tensor) - The scale factor between real space
          coordinates and its unsigned int value. The data type is float32 and the shape is :math:`(3,)`
        - **excluded_list_start** (Tensor) - The start excluded index
          in excluded list for each atom. The data type is int32 and the shape is :math:`(n,)`
        - **excluded_list** (Tensor) - The contiguous join of excluded
          list of each atom. E is the number of excluded atoms. The data type is int32 and the shape is :math:`(E,)`
        - **excluded_atom_numbers** (Tensor) - The number of atom excluded
          in excluded list for each atom. The data type is int32 and the shape is :math:`(n,)`

    Outputs:
        - **reciprocal_ene** (Tensor) - The reciprocal term of PME energy.
          The data type is float32 and the shape is :math:`(1,)`.
        - **self_ene** (Tensor) - The self term of PME energy.
          The data type is float32 and the shape is :math:`(1,)`.
        - **direct_ene** (Tensor) - The direct term of PME energy.
          The data type is float32 and the shape is :math:`(1,)`.
        - **correction_ene** (Tensor) - The correction term of PME energy.
          The data type is float32 and the shape is :math:`(1,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, excluded_numbers, beta, fftx, ffty, fftz, box_length_0, box_length_1,
                 box_length_2):
        """Initialize PMEEnergy."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('excluded_numbers', excluded_numbers, int, self.name)
        validator.check_value_type('beta', beta, float, self.name)
        validator.check_value_type('fftx', fftx, int, self.name)
        validator.check_value_type('ffty', ffty, int, self.name)
        validator.check_value_type('fftz', fftz, int, self.name)
        validator.check_value_type('box_length_0', box_length_0, float, self.name)
        validator.check_value_type('box_length_1', box_length_1, float, self.name)
        validator.check_value_type('box_length_2', box_length_2, float, self.name)
        self.atom_numbers = atom_numbers
        self.excluded_numbers = excluded_numbers
        self.beta = beta
        self.fftx = fftx
        self.ffty = ffty
        self.fftz = fftz
        self.box_length_0 = box_length_0
        self.box_length_1 = box_length_1
        self.box_length_2 = box_length_2
        # NOTE(review): nine input names are registered ('box_length' first) while
        # the infer methods take eight inputs; confirm the extra name is intended.
        self.init_prim_io_names(
            inputs=['box_length', 'uint_crd', 'charge', 'nl_numbers', 'nl_serial', 'scaler', 'excluded_list_start',
                    'excluded_list', 'excluded_atom_numbers'],
            outputs=['reciprocal_ene', 'self_ene', 'direct_ene', 'correction_ene'])
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.add_prim_attr('excluded_numbers', self.excluded_numbers)
        self.add_prim_attr('beta', self.beta)
        self.add_prim_attr('fftx', self.fftx)
        self.add_prim_attr('ffty', self.ffty)
        self.add_prim_attr('fftz', self.fftz)
        self.add_prim_attr('box_length_0', self.box_length_0)
        self.add_prim_attr('box_length_1', self.box_length_1)
        self.add_prim_attr('box_length_2', self.box_length_2)

    def infer_shape(self, uint_crd, charge, nl_numbers, nl_serial, scaler, excluded_list_start,
                    excluded_list, excluded_atom_numbers):
        """Validate the input shapes and return four ``(1,)`` output shapes.

        Every argument is the shape of the input of the same name. The shape
        of ``scaler`` is not inspected here.
        """
        cls_name = self.name
        n = self.atom_numbers
        validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
        validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
        # nl_serial is indexed at [0] and [1] below, so its rank must be exactly 2;
        # a "<= 2" check let lower ranks through and they failed with IndexError.
        validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
        validator.check_int(len(excluded_list_start), 1, Rel.EQ, "excluded_list_start_dim", cls_name)
        validator.check_int(len(excluded_atom_numbers), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name)
        validator.check_int(len(excluded_list), 1, Rel.GE, "excluded_list", cls_name)

        validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
        validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
        validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape[0]", cls_name)
        # The neighbor list may be smaller than (n, 800) but never larger.
        validator.check_int(nl_serial[0], n, Rel.LE, "nl_serial_shape[0]", cls_name)
        validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
        validator.check_int(excluded_list_start[0], n, Rel.EQ, "excluded_list_start_shape", cls_name)
        validator.check_int(excluded_atom_numbers[0], n, Rel.EQ, "excluded_atom_numbers_shape", cls_name)
        validator.check_int(excluded_list[0], 0, Rel.GE, "excluded_list_shape", cls_name)
        return (1,), (1,), (1,), (1,)

    def infer_dtype(self, uint_crd, charge, nl_numbers, nl_serial, scaler, excluded_list_start,
                    excluded_list, excluded_atom_numbers):
        """Validate the input dtypes; all four outputs take the dtype of charge.

        Every argument is the dtype of the input of the same name.
        """
        validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('excluded_list_start', excluded_list_start, [mstype.int32],
                                           self.name)
        validator.check_tensor_dtype_valid('excluded_list', excluded_list, [mstype.int32],
                                           self.name)
        validator.check_tensor_dtype_valid('excluded_atom_numbers', excluded_atom_numbers, [mstype.int32],
                                           self.name)
        return charge, charge, charge, charge
2255
-
2256
-
2257
class LJEnergy(PrimitiveWithInfer):
    """
    Calculate the Van der Waals interaction energy described by Lennard-Jones
    potential for each atom. Assume the number of atoms is n, and the number
    of Lennard-Jones types for all atoms is P, which means there will be
    q = P*(P+1)/2 types of possible Lennard-Jones interactions for all kinds
    of atom pairs.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::

        dr = (x_a-x_b, y_a-y_b, z_a-z_b)

    .. math::

        E = A/|dr|^{12} - B/|dr|^{6}

    Args:
        atom_numbers(int32): the number of atoms, n.
        cutoff_square(float32): the square value of cutoff.

    Inputs:
        - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`
        - **LJtype** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`
        - **charge** (Tensor) - The charge carried by each atom.
          The data type is float32 and the shape is :math:`(n,)`
        - **scaler** (Tensor) - The scale factor between real
          space coordinate and its unsigned int value. The data type is float32 and the shape is :math:`(3,)`
        - **nl_numbers** (Tensor) - A per-atom value (presumably the number of
          neighbors of each atom; TODO confirm).
          The data type is int32 and the shape is :math:`(n,)`
        - **nl_serial** (Tensor) - The neighbor list of each atom, the max number is 800.
          The data type is int32 and the shape is :math:`(n, 800)`.
        - **d_LJ_A** (Tensor) - The Lennard-Jones A coefficient of each kind of atom pair.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
        - **d_LJ_B** (Tensor) - The Lennard-Jones B coefficient of each kind of atom pair.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.

    Outputs:
        - **d_LJ_energy_atom** (Tensor) - The Lennard-Jones potential energy of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
          This is the only output the primitive registers and infers.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, cutoff_square):
        """Initialize LJEnergy."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('cutoff_square', cutoff_square, float, self.name)
        self.atom_numbers = atom_numbers
        self.cutoff_square = cutoff_square
        self.init_prim_io_names(
            inputs=['uint_crd', 'LJtype', 'charge', 'scaler', 'nl_numbers', 'nl_serial', 'd_LJ_A', 'd_LJ_B'],
            outputs=['d_LJ_energy_atom'])
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.add_prim_attr('cutoff_square', self.cutoff_square)

    def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
        """Validate the input shapes; the output takes the shape of charge.

        Every argument is the shape of the input of the same name.
        """
        cls_name = self.name
        n = self.atom_numbers
        validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
        validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
        validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
        validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
        validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)

        validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
        validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
        validator.check_int(ljtype[0], n, Rel.EQ, "LJtype_shape", cls_name)
        validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
        validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
        validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
        validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
        validator.check_int(len(d_lj_a), 1, Rel.EQ, "d_LJ_A_dim", cls_name)
        # Read q only after the rank of d_LJ_A is validated: reading it up front
        # raised IndexError on an empty shape instead of the validator's error.
        q = d_lj_a[0]
        # Compares d_LJ_A's length with itself, so it cannot fail on the value;
        # kept so the sequence of validator calls is unchanged.
        validator.check_int(d_lj_a[0], q, Rel.EQ, "d_LJ_A_shape[0]", cls_name)
        validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
        return charge

    def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
        """Validate the input dtypes; the output takes the dtype of charge.

        Every argument is the dtype of the input of the same name.
        """
        validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
        return charge
2356
-
2357
-
2358
- class LJForce(PrimitiveWithInfer):
2359
- """
2360
- Calculate the Van der Waals interaction force described by Lennard-Jones
2361
- potential energy for each atom.
2362
-
2363
- Because there is a large amount of inputs and each of them are related,
2364
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
2365
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
2366
-
2367
- .. math::
2368
-
2369
- dr = (x_a-x_b, y_a-y_b, z_a-z_b)
2370
-
2371
- .. math::
2372
-
2373
- F = (-12*A/|dr|^{14} + 6*B/|dr|^{8}) * dr
2374
-
2375
- Args:
2376
- atom_numbers(int32): the number of atoms, n.
2377
- cutoff_square(float32): the square value of cutoff.
2378
-
2379
- Inputs:
2380
- - **uint_crd** (Tensor) - The unsigned int coordinates value of each atom.
2381
- The data type is uint32 and the shape is :math:`(n, 3)`
2382
- - **LJtype** (Tensor) - The Lennard-Jones type of each atom.
2383
- The data type is int32 and the shape is :math:`(n,)`
2384
- - **charge** (Tensor) - The charge carried by each atom.
2385
- The data type is float32 and the shape is :math:`(n,)`
2386
- - **scaler** (Tensor) - The scale factor between real space
2387
- coordinates and its unsigned int value. The data type is float32 and the shape is :math:`(3,)`
2388
- - **nl_numbers** - (Tensor) - The each atom.
2389
- The data type is int32 and the shape is :math:`(n,)`
2390
- - **nl_serial** - (Tensor) - The neighbor list of each atom, the max number is 800.
2391
- The data type is int32 and the shape is :math:`(n, 800)`.
2392
- - **d_LJ_A** (Tensor) - The Lennard-Jones A coefficient of each kind of atom pair.
2393
- q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
2394
- - **d_LJ_B** (Tensor) - The Lennard-Jones B coefficient of each kind of atom pair.
2395
- q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
2396
-
2397
- outputs:
2398
- - **frc** (Tensor) - The force felt by each atom.
2399
- The data type is float32 and the shape is :math:`(n, 3)`.
2400
-
2401
- Supported Platforms:
2402
- ``GPU``
2403
- """
2404
-
2405
- @prim_attr_register
2406
- def __init__(self, atom_numbers, cutoff_square):
2407
- """Initialize LJForce."""
2408
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
2409
- validator.check_value_type('cutoff_square', cutoff_square, float, self.name)
2410
- self.atom_numbers = atom_numbers
2411
- self.cutoff_square = cutoff_square
2412
- self.init_prim_io_names(
2413
- inputs=['uint_crd', 'LJtype', 'charge', 'scaler', 'nl_numbers', 'nl_serial', 'd_LJ_A', 'd_LJ_B'],
2414
- outputs=['frc'])
2415
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2416
- self.add_prim_attr('cutoff_square', self.cutoff_square)
2417
-
2418
- def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
2419
- cls_name = self.name
2420
- n = self.atom_numbers
2421
- q = d_lj_a[0]
2422
- validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
2423
- validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
2424
- validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
2425
- validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
2426
- validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
2427
- validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
2428
- validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
2429
-
2430
- validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
2431
- validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
2432
- validator.check_int(ljtype[0], n, Rel.EQ, "LJtype_shape", cls_name)
2433
- validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
2434
- validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
2435
- validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
2436
- validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
2437
- validator.check_int(nl_serial[1], 800, Rel.EQ, "nl_serial_shape[1]", cls_name)
2438
- validator.check_int(len(d_lj_a), 1, Rel.EQ, "d_LJ_A_dim", cls_name)
2439
- validator.check_int(d_lj_a[0], q, Rel.EQ, "d_LJ_A_shape[0]", cls_name)
2440
- validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
2441
- return uint_crd
2442
-
2443
- def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
2444
- validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
2445
- validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
2446
- validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
2447
- validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
2448
- validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
2449
- validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
2450
- validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
2451
- validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
2452
- return charge
2453
-
2454
-
2455
- class LJForceWithPMEDirectForce(PrimitiveWithInfer):
2456
- """
2457
- Calculate the Lennard-Jones force and PME direct force together.
2458
-
2459
- The calculation formula of Lennard-Jones part is the same as operator
2460
- LJForce(), and the PME direct part is within PME method.
2461
-
2462
- Because there is a large amount of inputs and each of them are related,
2463
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
2464
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
2465
-
2466
- Args:
2467
- atom_numbers(int32): the number of atoms, n.
2468
- cutoff_square(float32): the square value of cutoff.
2469
- pme_beta(float32): PME beta parameter, same as operator PMEReciprocalForce().
2470
-
2471
- Inputs:
2472
- - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
2473
- The data type is uint32 and the shape is :math:`(n, 3)`.
2474
- - **LJtype** (Tensor) - The Lennard-Jones type of each atom.
2475
- The data type is int32 and the shape is :math:`(n,)`.
2476
- - **charge** (Tensor) - The charge carried by each atom.
2477
- The data type is float32 and the shape is :math:`(n,)`.
2478
- - **scaler** (Tensor) - The scale factor between real
2479
- space coordinate and its unsigned int value.
2480
- The data type is float32 and the shape is :math:`(3,)`.
2481
- - **nl_numbers** - (Tensor) - The each atom.
2482
- The data type is int32 and the shape is :math:`(n,)`.
2483
- - **nl_serial** - (Tensor) - The neighbor list of each atom, the max number is 800.
2484
- The data type is int32 and the shape is :math:`(n, 800)`.
2485
- - **d_LJ_A** (Tensor) - The Lennard-Jones A coefficient of each kind of atom pair.
2486
- q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
2487
- - **d_LJ_B** (Tensor) - The Lennard-Jones B coefficient of each kind of atom pair.
2488
- q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
2489
-
2490
- Outputs:
2491
- - **frc** (Tensor), The force felt by each atom.
2492
- The data type is float32 and the shape is :math:`(n, 3)`.
2493
-
2494
- Supported Platforms:
2495
- ``GPU``
2496
- """
2497
-
2498
- @prim_attr_register
2499
- def __init__(self, atom_numbers, cutoff, pme_beta):
2500
- """Initialize LJForceWithPMEDirectForce."""
2501
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
2502
- validator.check_value_type('cutoff', cutoff, float, self.name)
2503
- validator.check_value_type('pme_beta', pme_beta, float, self.name)
2504
- self.atom_numbers = atom_numbers
2505
- self.cutoff = cutoff
2506
- self.pme_beta = pme_beta
2507
- self.init_prim_io_names(
2508
- inputs=['uint_crd', 'LJtype', 'charge', 'scaler', 'nl_numbers', 'nl_serial', 'd_LJ_A', 'd_LJ_B'],
2509
- outputs=['frc'])
2510
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2511
- self.add_prim_attr('cutoff', self.cutoff)
2512
- self.add_prim_attr('pme_beta', self.pme_beta)
2513
-
2514
- def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
2515
- cls_name = self.name
2516
- n = self.atom_numbers
2517
- q = d_lj_a[0]
2518
- validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
2519
- validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
2520
- validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
2521
- validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
2522
- validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
2523
- validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
2524
- validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
2525
-
2526
- validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
2527
- validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
2528
- validator.check_int(ljtype[0], n, Rel.EQ, "LJtype_shape", cls_name)
2529
- validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
2530
- validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
2531
- validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
2532
- validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
2533
- validator.check_int(nl_serial[1], 800, Rel.EQ, "nl_serial_shape[1]", cls_name)
2534
- validator.check_int(len(d_lj_a), 1, Rel.EQ, "d_LJ_A_dim", cls_name)
2535
- validator.check_int(d_lj_a[0], q, Rel.EQ, "d_LJ_A_shape[0]", cls_name)
2536
- validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
2537
- return uint_crd
2538
-
2539
- def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
2540
- validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
2541
- validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
2542
- validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
2543
- validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
2544
- validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
2545
- validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
2546
- validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
2547
- validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
2548
- return charge
2549
-
2550
-
2551
- class MDTemperature(PrimitiveWithInfer):
2552
- """
2553
- Compute the MD temperature.
2554
-
2555
- Because there is a large amount of inputs and each of them are related,
2556
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
2557
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
2558
-
2559
- Args:
2560
- residue_numbers (int32): the number of residues m.
2561
- atom_numbers (int32): the number of atoms n.
2562
-
2563
- Inputs:
2564
- - **start** (Tensor) - The start atom index of each residue.
2565
- The data type is int32 and the shape is :math:`(m,)`.
2566
- - **end** (Tensor) - The end atom index of each residue.
2567
- The data type is int32 and the shape is :math:`(m,)`.
2568
- - **atom_vel_f** (Tensor) - The velocity of each atom.
2569
- The data type is float32 and the shape is :math:`(n, 3)`.
2570
- - **atom_mass** (Tensor) - The mass of each atom.
2571
- The data type is float32 and the shape is :math:`(n,)`.
2572
-
2573
- Outputs:
2574
- - **ek** (Tensor) - The temperature of each atom.
2575
- The data type is float32 and the shape is :math:`(n,)`.
2576
-
2577
- Supported Platforms:
2578
- ``GPU``
2579
- """
2580
-
2581
- @prim_attr_register
2582
- def __init__(self, residue_numbers, atom_numbers):
2583
- """Initialize MDTemperature."""
2584
- validator.check_value_type('residue_numbers', residue_numbers, int, self.name)
2585
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
2586
- self.residue_numbers = residue_numbers
2587
- self.atom_numbers = atom_numbers
2588
- self.add_prim_attr('residue_numbers', self.residue_numbers)
2589
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2590
- self.init_prim_io_names(
2591
- inputs=['start', 'end', 'atom_vel_f', 'atom_mass'],
2592
- outputs=['ek'])
2593
-
2594
- def infer_shape(self, start_shape, end_shape, atom_vel_f_shape, atom_mass_shape):
2595
- cls_name = self.name
2596
- n = self.residue_numbers
2597
- m = self.atom_numbers
2598
- validator.check_int(len(start_shape), 1, Rel.EQ, "start", cls_name)
2599
- validator.check_int(start_shape[0], n, Rel.EQ, "end", cls_name)
2600
- validator.check_int(len(end_shape), 1, Rel.EQ, "start", cls_name)
2601
- validator.check_int(end_shape[0], n, Rel.EQ, "end", cls_name)
2602
- validator.check_int(atom_vel_f_shape[0], m, Rel.EQ, "atom_vel_f", cls_name)
2603
- validator.check_int(atom_vel_f_shape[1], 3, Rel.EQ, "atom_vel_f", cls_name)
2604
- validator.check_int(len(atom_mass_shape), 1, Rel.EQ, "atom_mass", cls_name)
2605
- validator.check_int(atom_mass_shape[0], m, Rel.EQ, "atom_mass", cls_name)
2606
- return [n,]
2607
-
2608
- def infer_dtype(self, start_dtype, end_dtype, atom_vel_f_dtype, atom_mass_dtype):
2609
- validator.check_tensor_dtype_valid('start', start_dtype, [mstype.int32], self.name)
2610
- validator.check_tensor_dtype_valid('end', end_dtype, [mstype.int32], self.name)
2611
- validator.check_tensor_dtype_valid('atom_vel_f', atom_vel_f_dtype, [mstype.float32], self.name)
2612
- validator.check_tensor_dtype_valid('atom_mass', atom_mass_dtype, [mstype.float32], self.name)
2613
- return atom_mass_dtype
2614
-
2615
-
2616
- class MDIterationLeapFrogWithRF(PrimitiveWithInfer):
2617
- """
2618
- One step of classical leap frog algorithm to solve the finite difference
2619
- Hamiltonian equations of motion for certain system, using Langevin dynamics
2620
- with Liu's thermostat scheme. Assume the number of atoms is n and the target
2621
- control temperature is T.
2622
-
2623
- Detailed iteration formula can be found in this paper: A unified thermostat
2624
- scheme for efficient configurational sampling for classical/quantum canonical
2625
- ensembles via molecular dynamics. DOI: 10.1063/1.4991621.
2626
-
2627
- Because there is a large amount of inputs and each of them are related,
2628
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
2629
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
2630
-
2631
- Inputs:
2632
- - **float4_numbers** (Scalar) - total length to store random numbers.
2633
- The data type is int32.
2634
- - **atom_numbers** (Scalar) - The number of atoms n.
2635
- The data type is int32.
2636
- - **dt** (Scalar) - time step for finite difference. The data type is float32.
2637
- - **half_dt** (Scalar) - half of time step for finite difference.
2638
- The data type is float32.
2639
- - **exp_gamma** (Scalar) - parameter in Liu's dynamic, equals
2640
- exp(-gamma_ln * dt), where gamma_ln is the firction factor in Langvin
2641
- dynamics. The data type is float32.
2642
- - **max_velocity** (Scalar) - The upper limit of velocity, when the
2643
- velocity overflows, scale it to the upper limit. The data type is float32.
2644
- - **is_max_velocity** (Scalar) - whether the max velocity control is
2645
- open or not. The data type is int32.
2646
- - **mass_inverse** (Tensor) - The inverse value of
2647
- mass of each atom. The data type is float32 and the shape is :math:`(n,)`.
2648
- - **sqrt_mass** (Tensor) - The inverse square root value
2649
- of effect mass in Liu's dynamics of each atom.
2650
- The data type is float32 and the shape is :math:`(n,)`.
2651
- - **vel** (Tensor) - The velocity of each atom.
2652
- The data type is float32 and the shape is :math:`(n, 3)`.
2653
- - **crd** (Tensor) - The coordinate of each atom.
2654
- The data type is float32 and the shape is :math:`(n, 3)`.
2655
- - **frc** (Tensor) - The force felt by each atom.
2656
- The data type is float32 and the shape is :math:`(n, 3)`.
2657
- - **acc** (Tensor) - The acceleration of each atom.
2658
- The data type is float32 and the shape is :math:`(n, 3)`.
2659
- - **random force** (Tensor) - The random forces.
2660
- The data type is float32 and the shape is :math:`(n, 3)`.
2661
-
2662
- Outputs:
2663
- - **res** (Scalar) - The data type is float32.
2664
-
2665
- Supported Platforms:
2666
- ``GPU``
2667
- Examples:
2668
- """
2669
-
2670
- @prim_attr_register
2671
- def __init__(self, float4_numbers, atom_numbers, half_dt, dt, exp_gamma, is_max_velocity, max_velocity):
2672
- """Initialize MDIterationLeapFrogWithRF."""
2673
- validator.check_value_type('float4_numbers', float4_numbers, int, self.name)
2674
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
2675
- validator.check_value_type('half_dt', half_dt, float, self.name)
2676
- validator.check_value_type('dt', dt, float, self.name)
2677
- validator.check_value_type('exp_gamma', exp_gamma, float, self.name)
2678
- validator.check_value_type('is_max_velocity', is_max_velocity, int, self.name)
2679
- validator.check_value_type('max_velocity', max_velocity, float, self.name)
2680
- self.float4_numbers = float4_numbers
2681
- self.atom_numbers = atom_numbers
2682
- self.half_dt = half_dt
2683
- self.dt = dt
2684
- self.exp_gamma = exp_gamma
2685
- self.is_max_velocity = is_max_velocity
2686
- self.max_velocity = max_velocity
2687
-
2688
- self.init_prim_io_names(
2689
- inputs=['mass_inverse', 'sqrt_mass', 'vel_in', 'crd_in', 'frc_in', 'acc_in', 'random_force'],
2690
- outputs=['res'])
2691
- self.add_prim_attr('float4_numbers', self.float4_numbers)
2692
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2693
- self.add_prim_attr('half_dt', self.half_dt)
2694
- self.add_prim_attr('dt', self.dt)
2695
- self.add_prim_attr('exp_gamma', self.exp_gamma)
2696
- self.add_prim_attr('is_max_velocity', self.is_max_velocity)
2697
- self.add_prim_attr('max_velocity', self.max_velocity)
2698
-
2699
- def infer_shape(self, mass_inverse_shape, sqrt_mass_shape, vel_in_shape, crd_in_shape, frc_in_shape, acc_in_shape,
2700
- random_force_shape):
2701
- n = self.atom_numbers
2702
- validator.check_int(len(mass_inverse_shape), 1, Rel.EQ, "mass_inverse", self.name)
2703
- validator.check_int(len(sqrt_mass_shape), 1, Rel.EQ, "mass_inverse", self.name)
2704
- validator.check_int(mass_inverse_shape[0], n, Rel.EQ, "mass_inverse", self.name)
2705
- validator.check_int(sqrt_mass_shape[0], n, Rel.EQ, "mass_inverse", self.name)
2706
- validator.check_int(vel_in_shape[0], n, Rel.EQ, "vel_in", self.name)
2707
- validator.check_int(vel_in_shape[1], 3, Rel.EQ, "vel_in", self.name)
2708
- validator.check_int(crd_in_shape[0], n, Rel.EQ, "crd_in", self.name)
2709
- validator.check_int(crd_in_shape[1], 3, Rel.EQ, "crd_in", self.name)
2710
- validator.check_int(frc_in_shape[0], n, Rel.EQ, "frc_in", self.name)
2711
- validator.check_int(frc_in_shape[1], 3, Rel.EQ, "frc_in", self.name)
2712
- validator.check_int(acc_in_shape[0], n, Rel.EQ, "acc_in", self.name)
2713
- validator.check_int(acc_in_shape[1], 3, Rel.EQ, "acc_in", self.name)
2714
- validator.check_int(random_force_shape[0], n, Rel.EQ, "random_force", self.name)
2715
- validator.check_int(random_force_shape[1], 3, Rel.EQ, "random_force", self.name)
2716
-
2717
- return [1,]
2718
-
2719
- def infer_dtype(self, mass_inverse_dtype, sqrt_mass_dtype, vel_in_dtype, crd_in_dtype, frc_in_dtype, acc_in_dtype,
2720
- rf_dtype):
2721
- validator.check_tensor_dtype_valid('mass_inverse', mass_inverse_dtype, [mstype.float32], self.name)
2722
- validator.check_tensor_dtype_valid('sqrt_mass', sqrt_mass_dtype, [mstype.float32], self.name)
2723
- validator.check_tensor_dtype_valid('vel_in', vel_in_dtype, [mstype.float32], self.name)
2724
- validator.check_tensor_dtype_valid('crd_in', crd_in_dtype, [mstype.float32], self.name)
2725
- validator.check_tensor_dtype_valid('frc_in', frc_in_dtype, [mstype.float32], self.name)
2726
- validator.check_tensor_dtype_valid('acc_in', acc_in_dtype, [mstype.float32], self.name)
2727
- validator.check_tensor_dtype_valid('rf', rf_dtype, [mstype.float32], self.name)
2728
- return mstype.float32
2729
-
2730
-
2731
- class MDIterationLeapFrogLiujian(PrimitiveWithInfer):
2732
- """
2733
- One step of classical leap frog algorithm to solve the finite difference
2734
- Hamiltonian equations of motion for certain system, using Langevin dynamics
2735
- with Liu's thermostat scheme. Assume the number of atoms is n and the target
2736
- control temperature is T.
2737
-
2738
- Detailed iteration formula can be found in this paper: A unified thermostat
2739
- scheme for efficient configurational sampling for classical/quantum canonical
2740
- ensembles via molecular dynamics. DOI: 10.1063/1.4991621.
2741
-
2742
- Because there is a large amount of inputs and each of them are related,
2743
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
2744
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
2745
-
2746
- Args:
2747
- atom_numbers(int32): the number of atoms n.
2748
- dt(float32): time step for finite difference.
2749
- half_dt(float32): half of time step for finite difference.
2750
- exp_gamma(float32): parameter in Liu's dynamic.
2751
-
2752
- Inputs:
2753
- - **inverse_mass** (Tensor) - The inverse value of
2754
- mass of each atom. The data type is float32 and the shape is :math:`(n)`.
2755
- - **sqrt_mass_inverse** (Tensor) - The inverse square root value
2756
- of effect mass in Liu's dynamics of each atom.
2757
- The data type is float32 and the shape is :math:`(n,)`.
2758
- - **vel** (Tensor) - The velocity of each atom.
2759
- The data type is float32 and the shape is :math:`(n, 3)`.
2760
- - **crd** (Tensor) - The coordinate of each atom.
2761
- The data type is float32 and the shape is :math:`(n, 3)`.
2762
- - **frc** (Tensor) - The force felt by each atom.
2763
- The data type is float32 and the shape is :math:`(n, 3)`.
2764
- - **acc** (Tensor) - The acceleration of each atom.
2765
- The data type is float32 and the shape is :math:`(n, 3)`.
2766
- - **rand_state** (Tensor) - Random state to generate
2767
- random force. The data type is float32 and the shape is :math:`(math.ceil(n * 3.0 / 4.0) * 16, )`.
2768
- - **rand_frc** (Tensor) - The random forces.
2769
- The data type is float32 and the shape is :math:`(n, 3)`.
2770
-
2771
- Outputs:
2772
- - **output** (Tensor) - The output coordinates.
2773
- The data type is float32, and the shape is :math:`(n, 3)`.
2774
-
2775
- Supported Platforms:
2776
- ``GPU``
2777
- """
2778
-
2779
- @prim_attr_register
2780
- def __init__(self, atom_numbers, half_dt, dt, exp_gamma):
2781
- """Initialize MDIterationLeapFrogLiujian."""
2782
- self.atom_numbers = atom_numbers
2783
- self.half_dt = half_dt
2784
- self.dt = dt
2785
- self.exp_gamma = exp_gamma
2786
-
2787
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2788
- self.add_prim_attr('half_dt', self.half_dt)
2789
- self.add_prim_attr('dt', self.dt)
2790
- self.add_prim_attr('exp_gamma', self.exp_gamma)
2791
- self.init_prim_io_names(
2792
- inputs=['inverse_mass', 'sqrt_mass_inverse', 'vel', 'crd', 'frc', 'acc', 'rand_state', 'rand_frc'],
2793
- outputs=['output'])
2794
- self.add_prim_attr('side_effect_mem', True)
2795
-
2796
- def infer_shape(self, inverse_mass, sqrt_mass_inverse, vel, crd, frc, acc, rand_state, rand_frc):
2797
- n = self.atom_numbers
2798
- validator.check_int(len(inverse_mass), 1, Rel.EQ, "inverse_mass", self.name)
2799
- validator.check_int(len(sqrt_mass_inverse), 1, Rel.EQ, "sqrt_mass_inverse", self.name)
2800
- validator.check_int(inverse_mass[0], n, Rel.EQ, "inverse_mass", self.name)
2801
- validator.check_int(sqrt_mass_inverse[0], n, Rel.EQ, "sqrt_mass_inverse", self.name)
2802
- validator.check_int(len(rand_state), 1, Rel.EQ, "rand_state_dim", self.name)
2803
- validator.check_int(len(rand_frc), 2, Rel.EQ, "rand_frc_dim", self.name)
2804
- validator.check_int(len(vel), 2, Rel.EQ, "vel_dim", self.name)
2805
- validator.check_int(len(crd), 2, Rel.EQ, "crd_dim", self.name)
2806
- validator.check_int(len(frc), 2, Rel.EQ, "frc_dim", self.name)
2807
- validator.check_int(len(acc), 2, Rel.EQ, "acc_dim", self.name)
2808
- validator.check_int(vel[0], n, Rel.EQ, "vel_shape[0]", self.name)
2809
- validator.check_int(vel[1], 3, Rel.EQ, "vel_shape[1]", self.name)
2810
- validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", self.name)
2811
- validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", self.name)
2812
- validator.check_int(frc[0], n, Rel.EQ, "frc_shape[0]", self.name)
2813
- validator.check_int(frc[1], 3, Rel.EQ, "frc_shape[1]", self.name)
2814
- validator.check_int(acc[0], n, Rel.EQ, "acc_shape[0]", self.name)
2815
- validator.check_int(acc[1], 3, Rel.EQ, "acc_shape[1]", self.name)
2816
- validator.check_int(rand_frc[0], n, Rel.EQ, "rand_frc_shape[0]", self.name)
2817
- validator.check_int(rand_frc[1], 3, Rel.EQ, "rand_frc_shape[1]", self.name)
2818
- validator.check_int(rand_state[0], math.ceil(self.atom_numbers * 3 / 4.0) * 16, Rel.EQ, "rand_state", self.name)
2819
- return [self.atom_numbers, 3]
2820
-
2821
- def infer_dtype(self, inverse_mass, sqrt_mass_inverse, vel, crd, frc, acc, rand_state, rand_frc):
2822
- validator.check_tensor_dtype_valid('inverse_mass', inverse_mass, [mstype.float32], self.name)
2823
- validator.check_tensor_dtype_valid('sqrt_mass_inverse', sqrt_mass_inverse, [mstype.float32], self.name)
2824
- validator.check_tensor_dtype_valid('vel', vel, [mstype.float32], self.name)
2825
- validator.check_tensor_dtype_valid('crd', crd, [mstype.float32], self.name)
2826
- validator.check_tensor_dtype_valid('frc', frc, [mstype.float32], self.name)
2827
- validator.check_tensor_dtype_valid('acc', acc, [mstype.float32], self.name)
2828
- validator.check_tensor_dtype_valid('rand_frc', rand_frc, [mstype.float32], self.name)
2829
- validator.check_tensor_dtype_valid('rand_state', rand_state, [mstype.float32], self.name)
2830
- return mstype.float32
2831
-
2832
-
2833
- class CrdToUintCrd(PrimitiveWithInfer):
2834
- """
2835
- Convert FP32 coordinate to Uint32 coordinate.
2836
-
2837
- Because there is a large amount of inputs and each of them are related,
2838
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
2839
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
2840
-
2841
- Args:
2842
- atom_numbers(int32): the number of atoms n.
2843
-
2844
- Inputs:
2845
- - **crd_to_uint_crd_cof** (Tensor) - The scale factor
2846
- between the unsigned int value and the real space coordinates.
2847
- The data type is float32 and the shape is :math:`(3,)`.
2848
- - **crd** (Tensor) - The coordinate of each atom.
2849
- The data type is float32 and the shape is :math:`(n, 3)`.
2850
-
2851
- Outputs:
2852
- - **output** (Scalar) - The data type is uint32.
2853
-
2854
- Supported Platforms:
2855
- ``GPU``
2856
- """
2857
-
2858
- @prim_attr_register
2859
- def __init__(self, atom_numbers):
2860
- """Initialize CrdToUintCrd."""
2861
- self.atom_numbers = atom_numbers
2862
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2863
- self.init_prim_io_names(
2864
- inputs=['crd_to_uint_crd_cof', 'crd'],
2865
- outputs=['output'])
2866
-
2867
- def infer_shape(self, crd_to_uint_crd_cof, crd):
2868
- validator.check_int(crd_to_uint_crd_cof[0], 3, Rel.EQ, "crd_to_uint_crd_cof_shape", self.name)
2869
- validator.check_int(crd[0], self.atom_numbers, Rel.EQ, "crd[0]", self.name)
2870
- validator.check_int(crd[1], 3, Rel.EQ, "crd[1]", self.name)
2871
- return crd
2872
-
2873
- def infer_dtype(self, crd_to_uint_crd_cof, crd):
2874
- validator.check_tensor_dtype_valid('crd_to_uint_crd_cof', crd_to_uint_crd_cof, [mstype.float32], self.name)
2875
- validator.check_tensor_dtype_valid('crd', crd, [mstype.float32], self.name)
2876
- return mstype.uint32
2877
-
2878
-
2879
- class MDIterationSetupRandState(PrimitiveWithInfer):
2880
- """
2881
- Compute the random state of the iteration.
2882
-
2883
- Because there is a large amount of inputs and each of them are related,
2884
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
2885
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
2886
-
2887
- Args:
2888
- atom_numbers(int32): the number of atoms n.
2889
- seed(int32): random seed.
2890
-
2891
- Outputs:
2892
- - **output** (Tensor) random state.
2893
- The data type is float32 and the shape is :math:`(ceil(n * 3 / 4),)`.
2894
-
2895
- Supported Platforms:
2896
- ``GPU``
2897
- """
2898
-
2899
- @prim_attr_register
2900
- def __init__(self, atom_numbers, seed):
2901
- """Initialize MDIterationSetupRandState."""
2902
- self.atom_numbers = atom_numbers
2903
- self.seed = seed
2904
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2905
- self.add_prim_attr('seed', self.seed)
2906
- self.init_prim_io_names(
2907
- inputs=[],
2908
- outputs=['output'])
2909
-
2910
- def infer_shape(self):
2911
- float4_numbers = math.ceil(self.atom_numbers * 3 / 4.0)
2912
- curandsize = 64 / 4
2913
- return [float4_numbers * int(curandsize),]
2914
-
2915
- def infer_dtype(self):
2916
- return mstype.float32
2917
-
2918
-
2919
- class TransferCrd(PrimitiveWithInfer):
2920
- """
2921
- Transfer the coordinates to angular and radial.
2922
-
2923
- Because there is a large amount of inputs and each of them are related,
2924
- there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
2925
- <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
2926
-
2927
- Args:
2928
- start_serial(int32): the index start position.
2929
- end_serial(int32): the index end position.
2930
- number(int32): the length of angular and radial.
2931
-
2932
- Inputs:
2933
- - **crd** (Tensor) - The coordinate of each atom.
2934
- n is the number of atoms. The data type is float32 and the shape is :math:`(n, 3)`.
2935
- - **old_crd** (Tensor) - The last coordinate of each atom.
2936
- n is the number of atoms. The data type is float32 and the shape is :math:`(n, 3)`.
2937
- - **box** (Tensor) - The length of 3 dimensions of the simulation box.
2938
- The data type is float32 and the shape is :math:`(3,)`.
2939
-
2940
- Outputs:
2941
- - **radial** (Tensor) - The array of radial transferred from coordinates.
2942
- The data type is float32 and the shape is :math:`(number,)`.
2943
- - **angular** (Tensor) - The array of angular transferred from coordinates.
2944
- The data type is float32 and the shape is :math:`(number,)`.
2945
- - **nowarp_crd** (Tensor) - The modified coordinate of each atom for
2946
- computing radial and angular. The data type is float32 and the shape is :math:`(n, 3)`.
2947
- - **box_map_times** (Tensor) - The box map times for radial and angular.
2948
- The data type is int32 and the shape is :math:`(n, 3)`.
2949
-
2950
- Supported Platforms:
2951
- ``GPU``
2952
- """
2953
-
2954
- @prim_attr_register
2955
- def __init__(self, start_serial, end_serial, number, atom_numbers):
2956
- """Initialize TransferCrd."""
2957
- validator.check_value_type('start_serial', start_serial, int, self.name)
2958
- validator.check_value_type('end_serial', end_serial, int, self.name)
2959
- validator.check_value_type('number', number, int, self.name)
2960
- validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
2961
- self.start_serial = start_serial
2962
- self.end_serial = end_serial
2963
- self.number = number
2964
- self.atom_numbers = atom_numbers
2965
- self.add_prim_attr('start_serial', self.start_serial)
2966
- self.add_prim_attr('end_serial', self.end_serial)
2967
- self.add_prim_attr('number', self.number)
2968
- self.add_prim_attr('atom_numbers', self.atom_numbers)
2969
- self.init_prim_io_names(
2970
- inputs=['crd', 'old_crd', 'box'],
2971
- outputs=['radial', 'angular', 'nowarp_crd', 'box_map_times'])
2972
-
2973
- def infer_shape(self, crd_shape, old_crd_shape, box_shape):
2974
- n = self.atom_numbers
2975
- validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", self.name)
2976
- validator.check_int(crd_shape[0], n, Rel.EQ, "crd_shape[0]", self.name)
2977
- validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[1]", self.name)
2978
- validator.check_int(len(old_crd_shape), 2, Rel.EQ, "old_crd_dim", self.name)
2979
- validator.check_int(old_crd_shape[0], n, Rel.EQ, "old_crd_shape[0]", self.name)
2980
- validator.check_int(old_crd_shape[1], 3, Rel.EQ, "old_crd_shape[1]", self.name)
2981
- validator.check_int(len(box_shape), 1, Rel.EQ, "box_dim", self.name)
2982
- validator.check_int(box_shape[0], 3, Rel.EQ, "box_shape[0]", self.name)
2983
- return [self.number,], [self.number,], [n, 3], [n, 3]
2984
-
2985
- def infer_dtype(self, crd_dtype, old_crd_dtype, box_dtype):
2986
- validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name)
2987
- validator.check_tensor_dtype_valid('old_crd', old_crd_dtype, [mstype.float32], self.name)
2988
- validator.check_tensor_dtype_valid('box', box_dtype, [mstype.float32], self.name)
2989
- return mstype.float32, mstype.float32, mstype.float32, mstype.int32
2990
-
2991
-
2992
class PMEBatchedFFT2D(PrimitiveWithInfer):
    """
    FFT with N-dimensional input. According to the original note for this
    operator, only a batched ifft2 is currently provided.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        direction (int): Transform direction flag, stored as the ``direction``
            primitive attribute. Its accepted values are not defined in this class.

    Inputs:
        - **input_tensor** (Tensor) - Three-dimensional tensor, supported
          data type is complex64.

    Outputs:
        - **output_tensor** (Tensor) - The tensor after undergoing fast Fourier
          transform. The inferred shape equals the first three dimensions of
          the input and the data type is complex64.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, direction):
        """Initialize PMEBatchedFFT2D."""
        self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_tensor'])

        validator.check_value_type('direction', direction, int, self.name)
        self.direction = direction
        self.add_prim_attr('direction', self.direction)

    def infer_shape(self, input_shape):
        """Record the first three input dims as fftx/ffty/fftz and return them as the output shape."""
        # Attributes are recorded one axis at a time, in axis order.
        for axis, attr_name in enumerate(("fftx", "ffty", "fftz")):
            self.add_prim_attr(attr_name, input_shape[axis])
        return [input_shape[axis] for axis in range(3)]

    def infer_dtype(self, input_dtype):
        """Accept only complex64 input and report complex64 output."""
        supported = [mstype.complex64]
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, supported, self.name)
        return mstype.complex64
3030
-
3031
-
3032
class PMEFFT1D(PrimitiveWithInfer):
    """
    Forward FFT with One-Dimensional Input.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **input_tensor** (Tensor) - Input tensor of shape :math:`(fftx,)`
          (only the first dimension is read by shape inference), supported
          data type is complex64.

    Outputs:
        - **output_tensor** (Tensor) - The tensor after undergoing fast Fourier
          transform, with shape :math:`(fftx,)`. The data type is complex64.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PMEFFT1D."""
        self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        """Record the transform length as ``fftx`` and return the 1-D output shape."""
        length = input_shape[0]
        self.add_prim_attr('fftx', length)
        return [length]

    def infer_dtype(self, input_dtype):
        """Accept only complex64 input and report complex64 output."""
        supported = [mstype.complex64]
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, supported, self.name)
        return mstype.complex64
3064
-
3065
-
3066
class PMEIFFT1D(PrimitiveWithInfer):
    """
    Inverse FFT with One-Dimensional Input.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **input_tensor** (Tensor) - Input tensor of shape :math:`(fftx,)`
          (only the first dimension is read by shape inference), supported
          data type is complex64.

    Outputs:
        - **output_tensor** (Tensor) - The tensor after undergoing inverse
          Fourier transform, with shape :math:`(fftx,)`. The data type is complex64.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PMEIFFT1D."""
        self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        """Record the transform length as ``fftx`` and return the 1-D output shape."""
        length = input_shape[0]
        self.add_prim_attr('fftx', length)
        return [length]

    def infer_dtype(self, input_dtype):
        """Accept only complex64 input and report complex64 output."""
        supported = [mstype.complex64]
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, supported, self.name)
        return mstype.complex64
3098
-
3099
-
3100
class PMEFFT2D(PrimitiveWithInfer):
    """
    Forward FFT with Two-Dimensional Input.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **input_tensor** (Tensor) - Input tensor of shape :math:`(fftx, ffty)`
          (only the first two dimensions are read by shape inference),
          supported data type is complex64.

    Outputs:
        - **output_tensor** (Tensor) - The tensor after undergoing fast Fourier
          transform, with shape :math:`(fftx, ffty)`. The data type is complex64.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PMEFFT2D."""
        self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        """Record the two input dims as fftx/ffty and return them as the output shape."""
        # Attributes are recorded one axis at a time, in axis order.
        for axis, attr_name in enumerate(('fftx', 'ffty')):
            self.add_prim_attr(attr_name, input_shape[axis])
        rows, cols = input_shape[0], input_shape[1]
        return [rows, cols]

    def infer_dtype(self, input_dtype):
        """Accept only complex64 input and report complex64 output."""
        supported = [mstype.complex64]
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, supported, self.name)
        return mstype.complex64
3133
-
3134
-
3135
class PMEIFFT2D(PrimitiveWithInfer):
    """
    Inverse FFT with Two-Dimensional Input.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **input_tensor** (Tensor) - Input tensor of shape :math:`(fftx, ffty)`
          (only the first two dimensions are read by shape inference),
          supported data type is complex64.

    Outputs:
        - **output_tensor** (Tensor) - The tensor after undergoing inverse
          Fourier transform, with shape :math:`(fftx, ffty)`. The data type is complex64.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PMEIFFT2D."""
        self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        """Record the two input dims as fftx/ffty and return them as the output shape."""
        # Attributes are recorded one axis at a time, in axis order.
        for axis, attr_name in enumerate(('fftx', 'ffty')):
            self.add_prim_attr(attr_name, input_shape[axis])
        rows, cols = input_shape[0], input_shape[1]
        return [rows, cols]

    def infer_dtype(self, input_dtype):
        """Accept only complex64 input and report complex64 output."""
        supported = [mstype.complex64]
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, supported, self.name)
        return mstype.complex64
3168
-
3169
-
3170
class PMERFFT2D(PrimitiveWithInfer):
    """
    Forward FFT with Two-Dimensional Input for real -> complex.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **input_tensor** (Tensor) - Input tensor of shape :math:`(fftx, ffty)`
          (only the first two dimensions are read by shape inference),
          supported data type is float32.

    Outputs:
        - **output_tensor** (Tensor) - The tensor after undergoing fast Fourier
          transform. The inferred shape is :math:`(fftx, int(ffty / 2) + 1)`
          and the data type is complex64.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PMERFFT2D."""
        self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        """Record the input dims as fftx/ffty and return the real-to-complex output shape."""
        # Attributes are recorded one axis at a time, in axis order.
        for axis, attr_name in enumerate(('fftx', 'ffty')):
            self.add_prim_attr(attr_name, input_shape[axis])
        rows, cols = input_shape[0], input_shape[1]
        # True division followed by int() truncation is kept exactly as in the original.
        half_cols = int(cols / 2) + 1
        return [rows, half_cols]

    def infer_dtype(self, input_dtype):
        """Accept only float32 input and report complex64 output."""
        supported = [mstype.float32]
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, supported, self.name)
        return mstype.complex64
3203
-
3204
-
3205
class PMEIRFFT2D(PrimitiveWithInfer):
    """
    Inverse RFFT with Two-Dimensional Input.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **input_tensor** (Tensor) - Input tensor of shape :math:`(fftx, ffty)`
          (only the first two dimensions are read by shape inference),
          supported data type is complex64.

    Outputs:
        - **output_tensor** (Tensor) - The tensor after undergoing inverse
          Fourier transform. The inferred shape is :math:`(fftx, 2 * (ffty - 1))`
          and the data type is float32.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PMEIRFFT2D."""
        self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        """Record the input dims as fftx/ffty and return the complex-to-real output shape."""
        # Attributes are recorded one axis at a time, in axis order.
        for axis, attr_name in enumerate(('fftx', 'ffty')):
            self.add_prim_attr(attr_name, input_shape[axis])
        rows, cols = input_shape[0], input_shape[1]
        full_cols = 2 * (cols - 1)
        return [rows, full_cols]

    def infer_dtype(self, input_dtype):
        """Accept only complex64 input and report float32 output."""
        supported = [mstype.complex64]
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, supported, self.name)
        return mstype.float32
3238
-
3239
-
3240
class FFT3D(PrimitiveWithInfer):
    """
    Forward FFT with Three-Dimensional Input.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **input_tensor** (Tensor) - Three-dimensional tensor of shape
          :math:`(fftx, ffty, fftz)`, supported data type is float32.

    Outputs:
        - **output_tensor** (Tensor) - The tensor after undergoing fast Fourier
          transform. The inferred shape is :math:`(fftx, ffty, int(fftz / 2) + 1)`
          and the data type is complex64.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize FFT3D."""
        self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        """Record the input dims as fftx/ffty/fftz and return the real-to-complex output shape."""
        # Attributes are recorded one axis at a time, in axis order.
        for axis, attr_name in enumerate(('fftx', 'ffty', 'fftz')):
            self.add_prim_attr(attr_name, input_shape[axis])
        dim_x, dim_y, dim_z = input_shape[0], input_shape[1], input_shape[2]
        # True division followed by int() truncation is kept exactly as in the original.
        half_z = int(dim_z / 2) + 1
        return [dim_x, dim_y, half_z]

    def infer_dtype(self, input_dtype):
        """Accept only float32 input and report complex64 output."""
        supported = [mstype.float32]
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, supported, self.name)
        return mstype.complex64
3274
-
3275
-
3276
class IFFT3D(PrimitiveWithInfer):
    """
    Inverse FFT with Three-Dimensional Input.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **input_tensor** (Tensor) - Three-dimensional input tensor of shape
          :math:`(fftx, ffty, fftz)`, supported data type is complex64.

    Outputs:
        - **output_tensor** (Tensor) - The tensor after undergoing inverse
          Fourier transform. The inferred shape is :math:`(fftx, ffty, (fftz - 1) * 2)`
          and the data type is float32.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IFFT3D."""
        self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        """Record the input dims as fftx/ffty/fftz and return the complex-to-real output shape."""
        # Attributes are recorded one axis at a time, in axis order.
        for axis, attr_name in enumerate(('fftx', 'ffty', 'fftz')):
            self.add_prim_attr(attr_name, input_shape[axis])
        dim_x, dim_y, dim_z = input_shape[0], input_shape[1], input_shape[2]
        full_z = (dim_z - 1) * 2
        return [dim_x, dim_y, full_z]

    def infer_dtype(self, input_dtype):
        """Accept only complex64 input and report float32 output."""
        supported = [mstype.complex64]
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, supported, self.name)
        return mstype.float32
3310
-
3311
-
3312
class NeighborListUpdate(PrimitiveWithInfer):
    """
    Update (or construct if first time) the Verlet neighbor list for the
    calculation of short-ranged force. Assume the number of atoms is n,
    the number of grids divided is G, the maximum number of atoms in one
    grid is k, the maximum number of atoms in single atom's neighbor list
    is m, and the number of total atom in excluded list is E.

    Args:
        grid_numbers (int): the total number of grids divided G.
        atom_numbers (int): the number of atoms n.
        not_first_time (int): whether to construct the neighbor list first time or not.
        nxy (int): the total number of grids divided in xy plane.
        excluded_atom_numbers (int): the total atom numbers in the excluded list E.
        cutoff_square (float): the cutoff square distance for short-range force calculation.
        half_skin_square (float): the maximum square value of the distance atom allowed to move
            between two updates.
        cutoff_with_skin (float): cutoff + skin, indicates the radius of the neighbor list for each atom.
        half_cutoff_with_skin (float): cutoff_with_skin/2.
        cutoff_with_skin_square (float): the square value of cutoff_with_skin.
        refresh_interval (int): the number of iteration steps between two updates of neighbor list. Default: 20.
        cutoff (float): the cutoff distance for short-range force calculation. Default: 10.0.
        skin (float): the maximum value of the distance atom allowed to move. Default: 2.0.
        max_atom_in_grid_numbers (int): the maximum number of atoms in one grid k. Default: 64.
        max_neighbor_numbers (int): the maximum number of neighbors m. Default: 800.

    Inputs:
        - **atom_numbers_in_grid_bucket** (Tensor) - The number of atoms in each grid bucket.
          The data type is int32 and the shape is :math:`(G,)`.
        - **bucket** (Tensor) - The atom indices in each grid bucket.
          The data type is int32 and the shape is :math:`(G, k)`.
        - **crd** (Tensor) - The coordinates of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **box_length** (Tensor) - The box length of the simulation box.
          The data type is float32 and the shape is :math:`(3,)`.
        - **grid_N** (Tensor) - The number of grids divided of 3 dimensions of the simulation box.
          The data type is int32 and the shape is :math:`(3,)`.
        - **grid_length_inverse** (Tensor) - The inverse value of grid length.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_in_grid_serial** (Tensor) - The grid index for each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **old_crd** (Tensor) - The coordinates before update of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **crd_to_uint_crd_cof** (Tensor) - The scale factor between the unsigned int coordinate and the real one.
          The data type is float32 and the shape is :math:`(3,)`.
        - **uint_crd** (Tensor) - The unsigned int coordinates value of each atom.
          The data type is unsigned int32 and the shape is :math:`(n, 3)`.
        - **gpointer** (Tensor) - The nearest neighbor grids (including self) of each grid.
          The data type is int32 and the shape is :math:`(G, 125)`.
        - **nl_atom_numbers** (Tensor) - The number of atoms in neighbor list of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **nl_atom_serial** (Tensor) - The indices of atoms in neighbor list of each atom.
          The data type is int32 and the shape is :math:`(n, m)`.
        - **uint_dr_to_dr_cof** (Tensor) - The scale factor.
          The data type is float32 and the shape is :math:`(3,)`.
        - **excluded_list_start** (Tensor) - The start excluded index in excluded list for each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **excluded_list** (Tensor) - The contiguous join of excluded list of each atom.
          The data type is int32 and the shape is :math:`(E,)`.
        - **excluded_numbers** (Tensor) - The number of atom excluded in excluded list for each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **need_refresh_flag** (Tensor) - Whether the neighbor list of each atom need update or not.
          The data type is int32 and the shape is :math:`(1,)`.
        - **refresh_count** (Tensor) - Count how many iteration steps have passed since last update.
          The data type is int32 and the shape is :math:`(1,)` or :math:`()`.

    Outputs:
        - **res** (Tensor) - The return value after updating successfully.
          The data type is float32 and the shape is :math:`(1,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, grid_numbers, atom_numbers, not_first_time, nxy, excluded_atom_numbers,
                 cutoff_square, half_skin_square, cutoff_with_skin, half_cutoff_with_skin, cutoff_with_skin_square,
                 refresh_interval=20, cutoff=10.0, skin=2.0, max_atom_in_grid_numbers=64, max_neighbor_numbers=800):
        """Initialize NeighborListUpdate."""
        # (attribute name, value, required Python type), in the order the
        # settings are stored on the instance and then type-checked.
        settings = (
            ('grid_numbers', grid_numbers, int),
            ('atom_numbers', atom_numbers, int),
            ('refresh_interval', refresh_interval, int),
            ('not_first_time', not_first_time, int),
            ('cutoff', cutoff, float),
            ('skin', skin, float),
            ('max_atom_in_grid_numbers', max_atom_in_grid_numbers, int),
            ('nxy', nxy, int),
            ('excluded_atom_numbers', excluded_atom_numbers, int),
            ('cutoff_square', cutoff_square, float),
            ('half_skin_square', half_skin_square, float),
            ('cutoff_with_skin', cutoff_with_skin, float),
            ('half_cutoff_with_skin', half_cutoff_with_skin, float),
            ('cutoff_with_skin_square', cutoff_with_skin_square, float),
            ('max_neighbor_numbers', max_neighbor_numbers, int),
        )
        # All values are stored first; validation only starts afterwards.
        for attr_name, attr_value, _ in settings:
            setattr(self, attr_name, attr_value)
        for attr_name, attr_value, required_type in settings:
            validator.check_value_type(attr_name, attr_value, required_type, self.name)

        input_names = [
            'atom_numbers_in_grid_bucket', 'bucket', 'crd', 'box_length', 'grid_N', 'grid_length_inverse',
            'atom_in_grid_serial', 'old_crd', 'crd_to_uint_crd_cof', 'uint_crd', 'gpointer', 'nl_atom_numbers',
            'nl_atom_serial', 'uint_dr_to_dr_cof', 'excluded_list_start', 'excluded_list', 'excluded_numbers',
            'need_refresh_flag', 'refresh_count',
        ]
        self.init_prim_io_names(inputs=input_names, outputs=['res'])

        # Every setting except the trailing 'max_neighbor_numbers' is added
        # explicitly as a primitive attribute here.
        for attr_name, _, _ in settings[:-1]:
            self.add_prim_attr(attr_name, getattr(self, attr_name))
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, atom_numbers_in_grid_bucket_shape, bucket_shape, crd_shape, box_length_shape, grid_n_shape,
                    grid_length_inverse_shape, atom_in_grid_serial_shape, old_crd_shape, crd_to_uint_crd_cof_shape,
                    uint_crd_shape, gpointer_shape, nl_atom_numbers_shape, nl_atom_serial_shape,
                    uint_dr_to_dr_cof_shape, excluded_list_start_shape, excluded_list_shape, excluded_numbers_shape,
                    need_refresh_flag_shape, refresh_count_shape):
        """Validate the rank and extents of every input shape and return the output shape ``[1]``."""
        equal = Rel.EQ

        # Pass 1: ranks. Entries are (shape, required rank, label used in the error message).
        rank_checks = (
            (atom_numbers_in_grid_bucket_shape, 1, "atom_numbers_in_grid_bucket_dim"),
            (bucket_shape, 2, "bucket_dim"),
            (crd_shape, 2, "crd_dim"),
            (box_length_shape, 1, "box_length_dim"),
            (grid_n_shape, 1, "grid_n_dim"),
            (grid_length_inverse_shape, 1, "grid_length_inverse_dim"),
            (atom_in_grid_serial_shape, 1, "atom_in_grid_serial_dim"),
            (old_crd_shape, 2, "old_crd_dim"),
            (crd_to_uint_crd_cof_shape, 1, "crd_to_uint_crd_cof_dim"),
            (uint_crd_shape, 2, "uint_crd_dim"),
            (gpointer_shape, 2, "gpointer_dim"),
            (nl_atom_numbers_shape, 1, "nl_atom_numbers_dim"),
            (nl_atom_serial_shape, 2, "nl_atom_serial_dim"),
            (uint_dr_to_dr_cof_shape, 1, "uint_dr_to_dr_cof_dim"),
            (excluded_list_start_shape, 1, "excluded_list_start_dim"),
            (excluded_list_shape, 1, "excluded_list_dim"),
            (excluded_numbers_shape, 1, "excluded_numbers_dim"),
            (need_refresh_flag_shape, 1, "need_refresh_flag_dim"),
        )
        for shape, rank, label in rank_checks:
            validator.check_int(len(shape), rank, equal, label, self.name)
        # refresh_count may be a scalar (rank 0) or a rank-1 tensor.
        validator.check_int(len(refresh_count_shape), 1, Rel.LE, "refresh_count_dim", self.name)

        # Pass 2: extents. Entries are (shape, axis, required size, label).
        n_grids = self.grid_numbers
        n_atoms = self.atom_numbers
        extent_checks = (
            (atom_numbers_in_grid_bucket_shape, 0, n_grids, "atom_numbers_in_grid_bucket"),
            (bucket_shape, 0, n_grids, "bucket"),
            (bucket_shape, 1, self.max_atom_in_grid_numbers, "bucket"),
            (crd_shape, 0, n_atoms, "crd"),
            (crd_shape, 1, 3, "crd"),
            (box_length_shape, 0, 3, "box_length"),
            (grid_n_shape, 0, 3, "grid_N"),
            (grid_length_inverse_shape, 0, 3, "grid_length_inverse"),
            (atom_in_grid_serial_shape, 0, n_atoms, "atom_in_grid_serial"),
            (old_crd_shape, 0, n_atoms, "old_crd"),
            (old_crd_shape, 1, 3, "old_crd"),
            (crd_to_uint_crd_cof_shape, 0, 3, "crd_to_uint_crd_cof"),
            (uint_crd_shape, 0, n_atoms, "uint_crd"),
            (uint_crd_shape, 1, 3, "uint_crd"),
            (gpointer_shape, 0, n_grids, "gpointer"),
            (gpointer_shape, 1, 125, "gpointer"),
            (nl_atom_numbers_shape, 0, n_atoms, "nl_atom_numbers"),
            (nl_atom_serial_shape, 0, n_atoms, "nl_atom_serial"),
            (nl_atom_serial_shape, 1, self.max_neighbor_numbers, "nl_atom_serial"),
            (uint_dr_to_dr_cof_shape, 0, 3, "uint_dr_to_dr_cof"),
            (excluded_list_start_shape, 0, n_atoms, "excluded_list_start"),
            (excluded_list_shape, 0, self.excluded_atom_numbers, "excluded_list"),
            (excluded_numbers_shape, 0, n_atoms, "excluded_numbers"),
            (need_refresh_flag_shape, 0, 1, "need_refresh_flag"),
        )
        for shape, axis, required, label in extent_checks:
            validator.check_int(shape[axis], required, equal, label, self.name)
        # A scalar refresh_count has an empty shape, so there is no extent to check.
        if refresh_count_shape:
            validator.check_int(refresh_count_shape[0], 1, equal, "refresh_count_shape", self.name)
        return [1]

    def infer_dtype(self, atom_numbers_in_grid_bucket_dtype, bucket_dtype, crd_dtype, box_length_dtype, grid_n_dtype,
                    grid_length_inverse_dtype, atom_in_grid_serial_dtype, old_crd_dtype, crd_to_uint_crd_cof_dtype,
                    uint_crd_dtype, gpointer_dtype, nl_atom_numbers_dtype, nl_atom_serial_dtype,
                    uint_dr_to_dr_cof_dtype, excluded_list_start_dtype, excluded_list_dtype, excluded_numbers_dtype,
                    need_refresh_flag_dtype, refresh_count_dtype):
        """Validate the data type of every input and return the float32 output type."""
        int32, float32 = mstype.int32, mstype.float32
        # Entries are (input name, actual dtype, the single accepted dtype), checked in order.
        dtype_checks = (
            ('atom_numbers_in_grid_bucket', atom_numbers_in_grid_bucket_dtype, int32),
            ('bucket', bucket_dtype, int32),
            ('crd', crd_dtype, float32),
            ('box_length', box_length_dtype, float32),
            ('grid_N', grid_n_dtype, int32),
            ('grid_length_inverse', grid_length_inverse_dtype, float32),
            ('atom_in_grid_serial', atom_in_grid_serial_dtype, int32),
            ('old_crd', old_crd_dtype, float32),
            ('crd_to_uint_crd_cof', crd_to_uint_crd_cof_dtype, float32),
            ('uint_crd', uint_crd_dtype, mstype.uint32),
            ('gpointer', gpointer_dtype, int32),
            ('nl_atom_numbers', nl_atom_numbers_dtype, int32),
            ('nl_atom_serial', nl_atom_serial_dtype, int32),
            ('uint_dr_to_dr_cof', uint_dr_to_dr_cof_dtype, float32),
            ('excluded_list_start', excluded_list_start_dtype, int32),
            ('excluded_list', excluded_list_dtype, int32),
            ('excluded_numbers', excluded_numbers_dtype, int32),
            ('need_refresh_flag', need_refresh_flag_dtype, int32),
            ('refresh_count', refresh_count_dtype, int32),
        )
        for input_name, actual_dtype, accepted_dtype in dtype_checks:
            # A fresh one-element list is passed on every call.
            validator.check_tensor_dtype_valid(input_name, actual_dtype, [accepted_dtype], self.name)
        return float32