mindspore 2.3.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1400) hide show
  1. mindspore/.commit_id +1 -0
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +51 -0
  18. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +106 -0
  22. mindspore/_checkparam.py +1378 -0
  23. mindspore/_extends/__init__.py +23 -0
  24. mindspore/_extends/builtin_operations.py +224 -0
  25. mindspore/_extends/graph_kernel/__init__.py +17 -0
  26. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  27. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  28. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  29. mindspore/_extends/graph_kernel/model/model.py +553 -0
  30. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  31. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  32. mindspore/_extends/graph_kernel/splitter.py +140 -0
  33. mindspore/_extends/graph_kernel/utils.py +28 -0
  34. mindspore/_extends/parallel_compile/__init__.py +19 -0
  35. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  38. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  39. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  40. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  41. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  42. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  43. mindspore/_extends/parse/__init__.py +49 -0
  44. mindspore/_extends/parse/compile_config.py +258 -0
  45. mindspore/_extends/parse/namespace.py +136 -0
  46. mindspore/_extends/parse/parser.py +1446 -0
  47. mindspore/_extends/parse/resources.py +213 -0
  48. mindspore/_extends/parse/standard_method.py +4437 -0
  49. mindspore/_extends/parse/trope.py +97 -0
  50. mindspore/_extends/pijit/__init__.py +23 -0
  51. mindspore/_extends/pijit/pijit_func_white_list.py +343 -0
  52. mindspore/_extends/remote/__init__.py +19 -0
  53. mindspore/_extends/remote/kernel_build_server.py +199 -0
  54. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  57. mindspore/_extends/utils.py +68 -0
  58. mindspore/_install_custom.py +43 -0
  59. mindspore/_profiler.py +30 -0
  60. mindspore/amp.py +419 -0
  61. mindspore/atlprov.dll +0 -0
  62. mindspore/avcodec-59.dll +0 -0
  63. mindspore/avdevice-59.dll +0 -0
  64. mindspore/avfilter-8.dll +0 -0
  65. mindspore/avformat-59.dll +0 -0
  66. mindspore/avutil-57.dll +0 -0
  67. mindspore/boost/__init__.py +42 -0
  68. mindspore/boost/adasum.py +319 -0
  69. mindspore/boost/base.py +535 -0
  70. mindspore/boost/boost.py +400 -0
  71. mindspore/boost/boost_cell_wrapper.py +790 -0
  72. mindspore/boost/dim_reduce.py +323 -0
  73. mindspore/boost/grad_accumulation.py +79 -0
  74. mindspore/boost/grad_freeze.py +382 -0
  75. mindspore/boost/group_loss_scale_manager.py +166 -0
  76. mindspore/boost/less_batch_normalization.py +174 -0
  77. mindspore/c1.dll +0 -0
  78. mindspore/c1xx.dll +0 -0
  79. mindspore/c2.dll +0 -0
  80. mindspore/cfgpersist.dll +0 -0
  81. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  82. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  83. mindspore/common/__init__.py +84 -0
  84. mindspore/common/_auto_dynamic.py +68 -0
  85. mindspore/common/_decorator.py +50 -0
  86. mindspore/common/_jit_fallback_utils.py +110 -0
  87. mindspore/common/_monad.py +25 -0
  88. mindspore/common/_register_for_adapter.py +74 -0
  89. mindspore/common/_register_for_recompute.py +48 -0
  90. mindspore/common/_register_for_tensor.py +45 -0
  91. mindspore/common/_stub_tensor.py +210 -0
  92. mindspore/common/_utils.py +122 -0
  93. mindspore/common/api.py +2049 -0
  94. mindspore/common/auto_dynamic_shape.py +507 -0
  95. mindspore/common/dtype.py +422 -0
  96. mindspore/common/dump.py +131 -0
  97. mindspore/common/file_system.py +48 -0
  98. mindspore/common/generator.py +260 -0
  99. mindspore/common/hook_handle.py +155 -0
  100. mindspore/common/initializer.py +880 -0
  101. mindspore/common/jit_config.py +98 -0
  102. mindspore/common/lazy_inline.py +240 -0
  103. mindspore/common/mindir_util.py +111 -0
  104. mindspore/common/mutable.py +234 -0
  105. mindspore/common/no_inline.py +54 -0
  106. mindspore/common/np_dtype.py +25 -0
  107. mindspore/common/parameter.py +1048 -0
  108. mindspore/common/recompute.py +262 -0
  109. mindspore/common/seed.py +260 -0
  110. mindspore/common/sparse_tensor.py +1171 -0
  111. mindspore/common/symbol.py +122 -0
  112. mindspore/common/tensor.py +4859 -0
  113. mindspore/communication/__init__.py +37 -0
  114. mindspore/communication/_comm_helper.py +466 -0
  115. mindspore/communication/_hccl_management.py +297 -0
  116. mindspore/communication/comm_func.py +1140 -0
  117. mindspore/communication/management.py +673 -0
  118. mindspore/config/op_info.config +533 -0
  119. mindspore/context.py +1976 -0
  120. mindspore/d3dcompiler_47.dll +0 -0
  121. mindspore/dataset/__init__.py +90 -0
  122. mindspore/dataset/audio/__init__.py +61 -0
  123. mindspore/dataset/audio/transforms.py +3690 -0
  124. mindspore/dataset/audio/utils.py +386 -0
  125. mindspore/dataset/audio/validators.py +1172 -0
  126. mindspore/dataset/callback/__init__.py +20 -0
  127. mindspore/dataset/callback/ds_callback.py +368 -0
  128. mindspore/dataset/callback/validators.py +32 -0
  129. mindspore/dataset/core/__init__.py +13 -0
  130. mindspore/dataset/core/config.py +1088 -0
  131. mindspore/dataset/core/datatypes.py +101 -0
  132. mindspore/dataset/core/py_util_helpers.py +65 -0
  133. mindspore/dataset/core/validator_helpers.py +774 -0
  134. mindspore/dataset/debug/__init__.py +21 -0
  135. mindspore/dataset/debug/debug_hook.py +97 -0
  136. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  137. mindspore/dataset/engine/__init__.py +124 -0
  138. mindspore/dataset/engine/cache_admin.py +47 -0
  139. mindspore/dataset/engine/cache_client.py +129 -0
  140. mindspore/dataset/engine/datasets.py +4554 -0
  141. mindspore/dataset/engine/datasets_audio.py +911 -0
  142. mindspore/dataset/engine/datasets_standard_format.py +493 -0
  143. mindspore/dataset/engine/datasets_text.py +2161 -0
  144. mindspore/dataset/engine/datasets_user_defined.py +1114 -0
  145. mindspore/dataset/engine/datasets_vision.py +4816 -0
  146. mindspore/dataset/engine/iterators.py +342 -0
  147. mindspore/dataset/engine/obs/__init__.py +23 -0
  148. mindspore/dataset/engine/obs/config_loader.py +68 -0
  149. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  150. mindspore/dataset/engine/obs/util.py +475 -0
  151. mindspore/dataset/engine/offload.py +596 -0
  152. mindspore/dataset/engine/queue.py +250 -0
  153. mindspore/dataset/engine/samplers.py +895 -0
  154. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  155. mindspore/dataset/engine/validators.py +2875 -0
  156. mindspore/dataset/text/__init__.py +54 -0
  157. mindspore/dataset/text/transforms.py +1703 -0
  158. mindspore/dataset/text/utils.py +715 -0
  159. mindspore/dataset/text/validators.py +642 -0
  160. mindspore/dataset/transforms/__init__.py +48 -0
  161. mindspore/dataset/transforms/c_transforms.py +638 -0
  162. mindspore/dataset/transforms/py_transforms.py +393 -0
  163. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  164. mindspore/dataset/transforms/transforms.py +1260 -0
  165. mindspore/dataset/transforms/validators.py +410 -0
  166. mindspore/dataset/utils/__init__.py +19 -0
  167. mindspore/dataset/utils/browse_dataset.py +190 -0
  168. mindspore/dataset/utils/line_reader.py +124 -0
  169. mindspore/dataset/vision/__init__.py +68 -0
  170. mindspore/dataset/vision/c_transforms.py +2641 -0
  171. mindspore/dataset/vision/py_transforms.py +2120 -0
  172. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  173. mindspore/dataset/vision/transforms.py +7295 -0
  174. mindspore/dataset/vision/utils.py +863 -0
  175. mindspore/dataset/vision/validators.py +1482 -0
  176. mindspore/default_config.py +2 -0
  177. mindspore/dnnl.dll +0 -0
  178. mindspore/dpcmi.dll +0 -0
  179. mindspore/experimental/__init__.py +20 -0
  180. mindspore/experimental/map_parameter.py +309 -0
  181. mindspore/experimental/optim/__init__.py +40 -0
  182. mindspore/experimental/optim/adadelta.py +161 -0
  183. mindspore/experimental/optim/adagrad.py +168 -0
  184. mindspore/experimental/optim/adam.py +193 -0
  185. mindspore/experimental/optim/adamax.py +170 -0
  186. mindspore/experimental/optim/adamw.py +205 -0
  187. mindspore/experimental/optim/asgd.py +153 -0
  188. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  189. mindspore/experimental/optim/nadam.py +157 -0
  190. mindspore/experimental/optim/optimizer.py +259 -0
  191. mindspore/experimental/optim/radam.py +194 -0
  192. mindspore/experimental/optim/rmsprop.py +154 -0
  193. mindspore/experimental/optim/rprop.py +164 -0
  194. mindspore/experimental/optim/sgd.py +156 -0
  195. mindspore/hal/__init__.py +40 -0
  196. mindspore/hal/_ascend.py +57 -0
  197. mindspore/hal/_base.py +57 -0
  198. mindspore/hal/_cpu.py +56 -0
  199. mindspore/hal/_gpu.py +57 -0
  200. mindspore/hal/device.py +356 -0
  201. mindspore/hal/event.py +179 -0
  202. mindspore/hal/memory.py +326 -0
  203. mindspore/hal/stream.py +339 -0
  204. mindspore/include/OWNERS +7 -0
  205. mindspore/include/api/allocator.h +97 -0
  206. mindspore/include/api/callback/callback.h +93 -0
  207. mindspore/include/api/callback/ckpt_saver.h +41 -0
  208. mindspore/include/api/callback/loss_monitor.h +33 -0
  209. mindspore/include/api/callback/lr_scheduler.h +51 -0
  210. mindspore/include/api/callback/time_monitor.h +34 -0
  211. mindspore/include/api/callback/train_accuracy.h +37 -0
  212. mindspore/include/api/cell.h +90 -0
  213. mindspore/include/api/cfg.h +82 -0
  214. mindspore/include/api/context.h +602 -0
  215. mindspore/include/api/data_type.h +47 -0
  216. mindspore/include/api/delegate.h +178 -0
  217. mindspore/include/api/delegate_api.h +75 -0
  218. mindspore/include/api/dual_abi_helper.h +208 -0
  219. mindspore/include/api/format.h +28 -0
  220. mindspore/include/api/graph.h +46 -0
  221. mindspore/include/api/kernel.h +58 -0
  222. mindspore/include/api/kernel_api.h +168 -0
  223. mindspore/include/api/metrics/accuracy.h +36 -0
  224. mindspore/include/api/metrics/metrics.h +41 -0
  225. mindspore/include/api/model.h +438 -0
  226. mindspore/include/api/model_group.h +79 -0
  227. mindspore/include/api/model_parallel_runner.h +168 -0
  228. mindspore/include/api/serialization.h +185 -0
  229. mindspore/include/api/status.h +192 -0
  230. mindspore/include/api/types.h +431 -0
  231. mindspore/include/api/visible.h +41 -0
  232. mindspore/include/c_api/context_c.h +179 -0
  233. mindspore/include/c_api/data_type_c.h +52 -0
  234. mindspore/include/c_api/format_c.h +46 -0
  235. mindspore/include/c_api/model_c.h +347 -0
  236. mindspore/include/c_api/ms/abstract.h +67 -0
  237. mindspore/include/c_api/ms/attribute.h +197 -0
  238. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  239. mindspore/include/c_api/ms/base/macros.h +32 -0
  240. mindspore/include/c_api/ms/base/status.h +33 -0
  241. mindspore/include/c_api/ms/base/types.h +283 -0
  242. mindspore/include/c_api/ms/context.h +102 -0
  243. mindspore/include/c_api/ms/graph.h +160 -0
  244. mindspore/include/c_api/ms/node.h +606 -0
  245. mindspore/include/c_api/ms/tensor.h +161 -0
  246. mindspore/include/c_api/ms/value.h +84 -0
  247. mindspore/include/c_api/status_c.h +79 -0
  248. mindspore/include/c_api/tensor_c.h +146 -0
  249. mindspore/include/c_api/types_c.h +67 -0
  250. mindspore/include/dataset/config.h +163 -0
  251. mindspore/include/dataset/constants.h +363 -0
  252. mindspore/include/dataset/execute.h +196 -0
  253. mindspore/include/dataset/text.h +1092 -0
  254. mindspore/include/dataset/transforms.h +638 -0
  255. mindspore/include/dataset/vision.h +2125 -0
  256. mindspore/include/dataset/vision_ascend.h +206 -0
  257. mindspore/include/dataset/vision_lite.h +625 -0
  258. mindspore/jpeg62.dll +0 -0
  259. mindspore/log.py +633 -0
  260. mindspore/mindrecord/__init__.py +43 -0
  261. mindspore/mindrecord/common/__init__.py +17 -0
  262. mindspore/mindrecord/common/constant.py +20 -0
  263. mindspore/mindrecord/common/enums.py +44 -0
  264. mindspore/mindrecord/common/exceptions.py +311 -0
  265. mindspore/mindrecord/config.py +809 -0
  266. mindspore/mindrecord/filereader.py +174 -0
  267. mindspore/mindrecord/filewriter.py +705 -0
  268. mindspore/mindrecord/mindpage.py +210 -0
  269. mindspore/mindrecord/shardheader.py +141 -0
  270. mindspore/mindrecord/shardindexgenerator.py +74 -0
  271. mindspore/mindrecord/shardreader.py +117 -0
  272. mindspore/mindrecord/shardsegment.py +128 -0
  273. mindspore/mindrecord/shardutils.py +185 -0
  274. mindspore/mindrecord/shardwriter.py +237 -0
  275. mindspore/mindrecord/tools/__init__.py +17 -0
  276. mindspore/mindrecord/tools/cifar10.py +140 -0
  277. mindspore/mindrecord/tools/cifar100.py +153 -0
  278. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  279. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  280. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  281. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  282. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  283. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  284. mindspore/mindspore_backend.dll +0 -0
  285. mindspore/mindspore_common.dll +0 -0
  286. mindspore/mindspore_core.dll +0 -0
  287. mindspore/mindspore_glog.dll +0 -0
  288. mindspore/mindspore_np_dtype.dll +0 -0
  289. mindspore/mindspore_shared_lib.dll +0 -0
  290. mindspore/mint/__init__.py +1137 -0
  291. mindspore/mint/linalg/__init__.py +22 -0
  292. mindspore/mint/nn/__init__.py +512 -0
  293. mindspore/mint/nn/functional.py +573 -0
  294. mindspore/mint/optim/__init__.py +24 -0
  295. mindspore/mint/optim/adamw.py +185 -0
  296. mindspore/msobj140.dll +0 -0
  297. mindspore/mspdb140.dll +0 -0
  298. mindspore/mspdbcore.dll +0 -0
  299. mindspore/mspdbst.dll +0 -0
  300. mindspore/mspft140.dll +0 -0
  301. mindspore/msvcdis140.dll +0 -0
  302. mindspore/msvcp140.dll +0 -0
  303. mindspore/msvcp140_1.dll +0 -0
  304. mindspore/msvcp140_2.dll +0 -0
  305. mindspore/msvcp140_atomic_wait.dll +0 -0
  306. mindspore/msvcp140_codecvt_ids.dll +0 -0
  307. mindspore/multiprocessing/__init__.py +72 -0
  308. mindspore/nn/__init__.py +48 -0
  309. mindspore/nn/cell.py +2605 -0
  310. mindspore/nn/dynamic_lr.py +482 -0
  311. mindspore/nn/extend/__init__.py +29 -0
  312. mindspore/nn/extend/basic.py +140 -0
  313. mindspore/nn/extend/embedding.py +143 -0
  314. mindspore/nn/extend/layer/__init__.py +27 -0
  315. mindspore/nn/extend/layer/normalization.py +109 -0
  316. mindspore/nn/extend/pooling.py +117 -0
  317. mindspore/nn/grad/__init__.py +21 -0
  318. mindspore/nn/grad/cell_grad.py +196 -0
  319. mindspore/nn/layer/__init__.py +63 -0
  320. mindspore/nn/layer/activation.py +1655 -0
  321. mindspore/nn/layer/basic.py +1519 -0
  322. mindspore/nn/layer/channel_shuffle.py +90 -0
  323. mindspore/nn/layer/combined.py +248 -0
  324. mindspore/nn/layer/container.py +734 -0
  325. mindspore/nn/layer/conv.py +1505 -0
  326. mindspore/nn/layer/dense.py +204 -0
  327. mindspore/nn/layer/embedding.py +751 -0
  328. mindspore/nn/layer/embedding_service.py +531 -0
  329. mindspore/nn/layer/embedding_service_layer.py +393 -0
  330. mindspore/nn/layer/image.py +661 -0
  331. mindspore/nn/layer/math.py +1069 -0
  332. mindspore/nn/layer/normalization.py +1177 -0
  333. mindspore/nn/layer/padding.py +894 -0
  334. mindspore/nn/layer/pooling.py +2148 -0
  335. mindspore/nn/layer/rnn_cells.py +388 -0
  336. mindspore/nn/layer/rnns.py +849 -0
  337. mindspore/nn/layer/thor_layer.py +963 -0
  338. mindspore/nn/layer/timedistributed.py +155 -0
  339. mindspore/nn/layer/transformer.py +823 -0
  340. mindspore/nn/learning_rate_schedule.py +512 -0
  341. mindspore/nn/loss/__init__.py +36 -0
  342. mindspore/nn/loss/loss.py +2846 -0
  343. mindspore/nn/metrics.py +53 -0
  344. mindspore/nn/optim/__init__.py +44 -0
  345. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  346. mindspore/nn/optim/ada_grad.py +217 -0
  347. mindspore/nn/optim/adadelta.py +206 -0
  348. mindspore/nn/optim/adafactor.py +448 -0
  349. mindspore/nn/optim/adam.py +1297 -0
  350. mindspore/nn/optim/adamax.py +220 -0
  351. mindspore/nn/optim/adasum.py +548 -0
  352. mindspore/nn/optim/asgd.py +216 -0
  353. mindspore/nn/optim/ftrl.py +401 -0
  354. mindspore/nn/optim/lamb.py +296 -0
  355. mindspore/nn/optim/lars.py +202 -0
  356. mindspore/nn/optim/lazyadam.py +533 -0
  357. mindspore/nn/optim/momentum.py +239 -0
  358. mindspore/nn/optim/optimizer.py +1034 -0
  359. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  360. mindspore/nn/optim/rmsprop.py +264 -0
  361. mindspore/nn/optim/rprop.py +251 -0
  362. mindspore/nn/optim/sgd.py +237 -0
  363. mindspore/nn/optim/thor.py +1310 -0
  364. mindspore/nn/probability/__init__.py +22 -0
  365. mindspore/nn/probability/bijector/__init__.py +35 -0
  366. mindspore/nn/probability/bijector/bijector.py +337 -0
  367. mindspore/nn/probability/bijector/exp.py +65 -0
  368. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  369. mindspore/nn/probability/bijector/invert.py +126 -0
  370. mindspore/nn/probability/bijector/power_transform.py +196 -0
  371. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  372. mindspore/nn/probability/bijector/softplus.py +189 -0
  373. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  374. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  375. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  376. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  377. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  378. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  379. mindspore/nn/probability/distribution/__init__.py +56 -0
  380. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  381. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  382. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  383. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  384. mindspore/nn/probability/distribution/beta.py +391 -0
  385. mindspore/nn/probability/distribution/categorical.py +435 -0
  386. mindspore/nn/probability/distribution/cauchy.py +383 -0
  387. mindspore/nn/probability/distribution/distribution.py +827 -0
  388. mindspore/nn/probability/distribution/exponential.py +350 -0
  389. mindspore/nn/probability/distribution/gamma.py +391 -0
  390. mindspore/nn/probability/distribution/geometric.py +335 -0
  391. mindspore/nn/probability/distribution/gumbel.py +257 -0
  392. mindspore/nn/probability/distribution/half_normal.py +133 -0
  393. mindspore/nn/probability/distribution/laplace.py +128 -0
  394. mindspore/nn/probability/distribution/log_normal.py +272 -0
  395. mindspore/nn/probability/distribution/logistic.py +379 -0
  396. mindspore/nn/probability/distribution/normal.py +336 -0
  397. mindspore/nn/probability/distribution/poisson.py +288 -0
  398. mindspore/nn/probability/distribution/student_t.py +149 -0
  399. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  400. mindspore/nn/probability/distribution/uniform.py +375 -0
  401. mindspore/nn/reinforcement/__init__.py +24 -0
  402. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  403. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  404. mindspore/nn/reinforcement/tensor_array.py +145 -0
  405. mindspore/nn/sparse/__init__.py +23 -0
  406. mindspore/nn/sparse/sparse.py +147 -0
  407. mindspore/nn/wrap/__init__.py +49 -0
  408. mindspore/nn/wrap/cell_wrapper.py +979 -0
  409. mindspore/nn/wrap/grad_reducer.py +608 -0
  410. mindspore/nn/wrap/loss_scale.py +680 -0
  411. mindspore/numpy/__init__.py +121 -0
  412. mindspore/numpy/array_creations.py +2734 -0
  413. mindspore/numpy/array_ops.py +2625 -0
  414. mindspore/numpy/dtypes.py +185 -0
  415. mindspore/numpy/fft.py +431 -0
  416. mindspore/numpy/logic_ops.py +935 -0
  417. mindspore/numpy/math_ops.py +5910 -0
  418. mindspore/numpy/utils.py +214 -0
  419. mindspore/numpy/utils_const.py +565 -0
  420. mindspore/opencv_core452.dll +0 -0
  421. mindspore/opencv_imgcodecs452.dll +0 -0
  422. mindspore/opencv_imgproc452.dll +0 -0
  423. mindspore/ops/__init__.py +54 -0
  424. mindspore/ops/_constants.py +30 -0
  425. mindspore/ops/_grad_experimental/__init__.py +31 -0
  426. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  427. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  428. mindspore/ops/_grad_experimental/grad_comm_ops.py +670 -0
  429. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  430. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  431. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  432. mindspore/ops/_grad_experimental/grad_math_ops.py +824 -0
  433. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  434. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  435. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  436. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  437. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  438. mindspore/ops/_op_impl/__init__.py +23 -0
  439. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  440. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  441. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  442. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  443. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  444. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  445. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  446. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  447. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  448. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  449. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  450. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  451. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  452. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  453. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  454. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  455. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  456. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  457. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  458. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  459. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  460. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  461. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  462. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  463. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  464. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  465. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  466. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  467. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  468. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  469. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  470. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  471. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  472. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  473. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  474. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  475. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  476. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  477. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  478. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  479. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  480. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  481. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  482. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  483. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  484. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  485. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  486. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  487. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  488. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  489. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  490. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  491. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  492. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  493. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  494. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  495. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  497. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  498. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  499. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  500. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  501. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  502. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  503. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  504. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  505. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  506. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  507. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  508. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  509. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  510. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  511. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  512. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  513. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  514. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  515. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  516. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  519. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  520. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  521. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  522. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  523. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  524. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  525. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  526. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  527. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  528. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  529. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  530. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  531. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  532. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  533. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  534. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  535. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  536. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  537. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  538. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  539. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  540. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  541. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  542. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  543. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  544. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  545. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  546. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  547. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  548. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  549. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  550. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  551. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  552. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  553. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  554. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  555. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  556. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  557. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  558. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  559. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  560. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  561. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  562. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  563. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  564. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  565. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  566. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  567. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  568. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  569. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  570. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  571. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  572. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  573. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  574. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  575. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  576. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  577. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  578. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  579. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  580. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  581. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  582. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  583. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  584. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  585. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  586. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  587. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  588. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  589. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  590. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  591. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  592. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  593. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  594. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  595. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  596. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  597. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  598. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  599. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  600. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  601. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  602. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  603. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  604. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  605. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  606. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  607. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  608. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  609. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  610. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  611. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  612. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  613. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  614. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  615. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  616. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  617. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  618. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  619. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  620. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  621. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  622. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  623. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  624. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  625. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  626. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  627. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  628. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  629. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  630. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  631. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  632. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  633. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  634. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  635. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  636. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  637. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  638. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  640. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  641. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  643. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  644. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  645. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  646. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  647. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  648. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  649. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  650. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  651. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  652. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  653. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  654. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  655. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  656. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  658. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  659. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  660. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  661. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  662. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  663. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  664. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  665. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  666. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  667. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  668. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  669. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  670. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  671. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  672. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  673. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  674. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  675. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  676. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  677. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  678. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  679. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  680. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  681. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  682. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  683. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  684. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  685. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  686. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  687. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  688. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  689. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  690. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  691. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  692. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  693. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  694. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  695. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  696. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  697. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  698. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  699. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  700. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  701. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  702. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  703. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  704. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  705. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  706. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  707. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  708. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  709. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  710. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  711. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  712. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  713. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  714. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  715. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  716. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  717. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  718. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  719. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  720. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  721. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  722. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  723. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  724. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  725. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  727. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  729. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  730. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  732. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  733. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  734. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  735. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  736. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  737. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  738. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  739. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  740. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  741. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  742. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  744. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  745. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  746. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  747. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  749. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  750. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  751. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  752. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  753. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  754. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  755. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  756. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  757. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  758. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  759. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  760. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  761. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  762. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  763. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  764. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  765. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  766. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  767. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  768. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  769. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  771. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  772. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  773. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  774. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  775. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  776. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  777. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  778. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  779. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  780. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  781. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  782. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  783. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  784. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  785. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  786. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  787. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  788. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  789. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  790. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  791. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  793. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  794. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  795. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  796. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  797. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  798. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  799. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  800. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  801. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  802. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  803. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  804. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  805. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  806. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  807. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  808. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  809. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  810. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  811. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  812. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  813. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  814. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  815. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  816. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  817. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  818. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  819. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  820. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  821. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  822. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  823. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  824. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  825. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  826. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  827. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  841. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  842. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  843. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  844. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  845. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  846. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  847. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  848. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  849. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  850. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  851. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  852. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  853. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  854. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  855. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  856. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  857. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  858. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  859. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  860. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  861. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  862. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  863. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  865. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  866. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  867. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  868. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  869. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  870. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  871. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  873. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  874. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  875. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  876. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  877. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  878. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  879. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  880. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  881. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  882. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  883. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  884. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  885. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  886. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  887. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  888. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  889. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  890. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  891. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  892. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  893. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  894. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  895. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  896. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  897. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  898. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  899. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  900. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  901. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  902. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  903. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  904. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  905. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  906. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  907. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  908. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  909. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  910. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  911. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  912. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  913. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  914. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  915. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  916. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  917. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  918. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  919. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  920. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  921. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  922. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  923. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  924. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  925. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  926. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  927. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  929. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  930. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  931. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  932. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  933. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  934. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  935. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  936. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  937. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  938. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  939. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  940. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  941. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  942. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  943. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  944. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  945. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  946. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  947. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  948. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  949. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  950. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  951. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  952. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  953. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  954. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  955. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  956. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  957. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  958. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  959. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  960. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  961. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  962. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  963. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  964. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  965. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  966. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  967. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  968. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  969. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  970. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  971. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  972. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  973. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  974. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  975. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  976. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  977. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  978. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  979. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  980. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  981. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  982. mindspore/ops/_op_impl/cpu/div.py +32 -0
  983. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  984. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  985. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  986. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  987. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  988. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  989. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  990. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  991. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  992. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  993. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  994. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  995. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  996. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  997. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  998. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  999. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  1000. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  1001. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  1002. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  1003. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  1004. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  1005. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  1006. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  1007. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  1009. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  1010. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  1011. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  1012. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  1013. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  1014. mindspore/ops/_op_impl/cpu/range.py +34 -0
  1015. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  1016. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  1017. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  1018. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  1019. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  1020. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  1021. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  1022. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1023. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1024. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1025. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1026. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1027. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1028. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1029. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1030. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1031. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1032. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1033. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1034. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1035. mindspore/ops/_primitive_cache.py +90 -0
  1036. mindspore/ops/_register_for_op.py +73 -0
  1037. mindspore/ops/_utils/__init__.py +20 -0
  1038. mindspore/ops/_utils/utils.py +147 -0
  1039. mindspore/ops/_vmap/__init__.py +25 -0
  1040. mindspore/ops/_vmap/vmap_array_ops.py +2151 -0
  1041. mindspore/ops/_vmap/vmap_base.py +533 -0
  1042. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1043. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1044. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1045. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1046. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1047. mindspore/ops/_vmap/vmap_math_ops.py +977 -0
  1048. mindspore/ops/_vmap/vmap_nn_ops.py +2209 -0
  1049. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1050. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1051. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1052. mindspore/ops/auto_generate/__init__.py +31 -0
  1053. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +231 -0
  1054. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +250 -0
  1055. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1056. mindspore/ops/auto_generate/gen_extend_func.py +980 -0
  1057. mindspore/ops/auto_generate/gen_ops_def.py +6443 -0
  1058. mindspore/ops/auto_generate/gen_ops_prim.py +13167 -0
  1059. mindspore/ops/auto_generate/pyboost_inner_prim.py +429 -0
  1060. mindspore/ops/composite/__init__.py +71 -0
  1061. mindspore/ops/composite/base.py +1281 -0
  1062. mindspore/ops/composite/env_ops.py +41 -0
  1063. mindspore/ops/composite/math_ops.py +125 -0
  1064. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1065. mindspore/ops/composite/multitype_ops/_compile_utils.py +1458 -0
  1066. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1067. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1068. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1069. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1070. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1071. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1072. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1073. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1074. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1075. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1076. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1077. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1078. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1079. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1080. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1081. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1082. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1083. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1084. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1085. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1086. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1087. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1088. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1089. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1090. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1091. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1092. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1093. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1094. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1095. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1096. mindspore/ops/deprecated.py +315 -0
  1097. mindspore/ops/extend/__init__.py +53 -0
  1098. mindspore/ops/extend/array_func.py +218 -0
  1099. mindspore/ops/extend/math_func.py +76 -0
  1100. mindspore/ops/extend/nn_func.py +308 -0
  1101. mindspore/ops/function/__init__.py +760 -0
  1102. mindspore/ops/function/array_func.py +6889 -0
  1103. mindspore/ops/function/clip_func.py +384 -0
  1104. mindspore/ops/function/debug_func.py +69 -0
  1105. mindspore/ops/function/fft_func.py +31 -0
  1106. mindspore/ops/function/grad/__init__.py +34 -0
  1107. mindspore/ops/function/grad/grad_func.py +1424 -0
  1108. mindspore/ops/function/image_func.py +292 -0
  1109. mindspore/ops/function/linalg_func.py +416 -0
  1110. mindspore/ops/function/math_func.py +11877 -0
  1111. mindspore/ops/function/nn_func.py +8175 -0
  1112. mindspore/ops/function/other_func.py +114 -0
  1113. mindspore/ops/function/parameter_func.py +134 -0
  1114. mindspore/ops/function/random_func.py +1539 -0
  1115. mindspore/ops/function/reshard_func.py +102 -0
  1116. mindspore/ops/function/sparse_func.py +884 -0
  1117. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1118. mindspore/ops/function/spectral_func.py +150 -0
  1119. mindspore/ops/function/vmap_func.py +116 -0
  1120. mindspore/ops/functional.py +454 -0
  1121. mindspore/ops/op_info_register.py +1572 -0
  1122. mindspore/ops/operations/__init__.py +717 -0
  1123. mindspore/ops/operations/_csr_ops.py +403 -0
  1124. mindspore/ops/operations/_custom_grad.py +181 -0
  1125. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1126. mindspore/ops/operations/_grad_ops.py +3052 -0
  1127. mindspore/ops/operations/_infer_ops.py +19 -0
  1128. mindspore/ops/operations/_inner_ops.py +2567 -0
  1129. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1130. mindspore/ops/operations/_ms_kernel.py +601 -0
  1131. mindspore/ops/operations/_ocr_ops.py +379 -0
  1132. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1133. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1134. mindspore/ops/operations/_quant_ops.py +1844 -0
  1135. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1136. mindspore/ops/operations/_scalar_ops.py +106 -0
  1137. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1138. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1139. mindspore/ops/operations/_tensor_array.py +359 -0
  1140. mindspore/ops/operations/_thor_ops.py +807 -0
  1141. mindspore/ops/operations/array_ops.py +6258 -0
  1142. mindspore/ops/operations/comm_ops.py +1996 -0
  1143. mindspore/ops/operations/control_ops.py +127 -0
  1144. mindspore/ops/operations/custom_ops.py +1065 -0
  1145. mindspore/ops/operations/debug_ops.py +646 -0
  1146. mindspore/ops/operations/image_ops.py +1041 -0
  1147. mindspore/ops/operations/inner_ops.py +697 -0
  1148. mindspore/ops/operations/linalg_ops.py +95 -0
  1149. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1150. mindspore/ops/operations/manually_defined/_inner.py +61 -0
  1151. mindspore/ops/operations/manually_defined/ops_def.py +2016 -0
  1152. mindspore/ops/operations/math_ops.py +5306 -0
  1153. mindspore/ops/operations/nn_ops.py +9669 -0
  1154. mindspore/ops/operations/other_ops.py +871 -0
  1155. mindspore/ops/operations/random_ops.py +1243 -0
  1156. mindspore/ops/operations/reshard_ops.py +53 -0
  1157. mindspore/ops/operations/rl_ops.py +288 -0
  1158. mindspore/ops/operations/sparse_ops.py +2753 -0
  1159. mindspore/ops/operations/spectral_ops.py +111 -0
  1160. mindspore/ops/primitive.py +1034 -0
  1161. mindspore/ops/signature.py +54 -0
  1162. mindspore/ops/silent_check.py +162 -0
  1163. mindspore/ops/vm_impl_registry.py +91 -0
  1164. mindspore/ops_generate/__init__.py +27 -0
  1165. mindspore/ops_generate/arg_dtype_cast.py +250 -0
  1166. mindspore/ops_generate/arg_handler.py +197 -0
  1167. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1168. mindspore/ops_generate/gen_ops.py +1084 -0
  1169. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1170. mindspore/ops_generate/gen_pyboost_func.py +968 -0
  1171. mindspore/ops_generate/gen_utils.py +209 -0
  1172. mindspore/ops_generate/op_proto.py +138 -0
  1173. mindspore/ops_generate/pyboost_utils.py +354 -0
  1174. mindspore/ops_generate/template.py +239 -0
  1175. mindspore/parallel/__init__.py +28 -0
  1176. mindspore/parallel/_auto_parallel_context.py +1466 -0
  1177. mindspore/parallel/_cell_wrapper.py +91 -0
  1178. mindspore/parallel/_cost_model_context.py +700 -0
  1179. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1180. mindspore/parallel/_offload_context.py +275 -0
  1181. mindspore/parallel/_parallel_serialization.py +533 -0
  1182. mindspore/parallel/_ps_context.py +242 -0
  1183. mindspore/parallel/_recovery_context.py +110 -0
  1184. mindspore/parallel/_tensor.py +660 -0
  1185. mindspore/parallel/_transformer/__init__.py +35 -0
  1186. mindspore/parallel/_transformer/layers.py +765 -0
  1187. mindspore/parallel/_transformer/loss.py +251 -0
  1188. mindspore/parallel/_transformer/moe.py +693 -0
  1189. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1190. mindspore/parallel/_transformer/transformer.py +3119 -0
  1191. mindspore/parallel/_utils.py +600 -0
  1192. mindspore/parallel/algo_parameter_config.py +400 -0
  1193. mindspore/parallel/checkpoint_transform.py +643 -0
  1194. mindspore/parallel/cluster/__init__.py +15 -0
  1195. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1196. mindspore/parallel/cluster/process_entity/_api.py +344 -0
  1197. mindspore/parallel/cluster/process_entity/_utils.py +126 -0
  1198. mindspore/parallel/cluster/run.py +136 -0
  1199. mindspore/parallel/mpi/__init__.py +14 -0
  1200. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1201. mindspore/parallel/parameter_broadcast.py +152 -0
  1202. mindspore/parallel/shard.py +350 -0
  1203. mindspore/perf_msvcbuildinsights.dll +0 -0
  1204. mindspore/pgodb140.dll +0 -0
  1205. mindspore/pgort140.dll +0 -0
  1206. mindspore/profiler/__init__.py +27 -0
  1207. mindspore/profiler/common/__init__.py +14 -0
  1208. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1209. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1210. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1211. mindspore/profiler/common/process_pool.py +41 -0
  1212. mindspore/profiler/common/singleton.py +28 -0
  1213. mindspore/profiler/common/struct_type.py +118 -0
  1214. mindspore/profiler/common/util.py +444 -0
  1215. mindspore/profiler/common/validator/__init__.py +14 -0
  1216. mindspore/profiler/common/validator/validate_path.py +84 -0
  1217. mindspore/profiler/envprofiling.py +256 -0
  1218. mindspore/profiler/parser/__init__.py +14 -0
  1219. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1220. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1221. mindspore/profiler/parser/ascend_analysis/constant.py +53 -0
  1222. mindspore/profiler/parser/ascend_analysis/file_manager.py +159 -0
  1223. mindspore/profiler/parser/ascend_analysis/function_event.py +161 -0
  1224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +131 -0
  1225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +85 -0
  1226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +57 -0
  1227. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
  1228. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1229. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +68 -0
  1230. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1231. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1232. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1233. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1234. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1235. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1236. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1237. mindspore/profiler/parser/ascend_msprof_exporter.py +281 -0
  1238. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1239. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1240. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1241. mindspore/profiler/parser/ascend_timeline_generator.py +543 -0
  1242. mindspore/profiler/parser/base_timeline_generator.py +489 -0
  1243. mindspore/profiler/parser/container.py +229 -0
  1244. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  1245. mindspore/profiler/parser/flops_parser.py +531 -0
  1246. mindspore/profiler/parser/framework_enum.py +111 -0
  1247. mindspore/profiler/parser/framework_parser.py +854 -0
  1248. mindspore/profiler/parser/framework_struct.py +61 -0
  1249. mindspore/profiler/parser/hccl_parser.py +573 -0
  1250. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1251. mindspore/profiler/parser/integrator.py +526 -0
  1252. mindspore/profiler/parser/memory_usage_parser.py +431 -0
  1253. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1254. mindspore/profiler/parser/minddata_parser.py +186 -0
  1255. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1256. mindspore/profiler/parser/msadvisor_analyzer.py +82 -0
  1257. mindspore/profiler/parser/msadvisor_parser.py +240 -0
  1258. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1259. mindspore/profiler/parser/optime_parser.py +250 -0
  1260. mindspore/profiler/parser/profiler_info.py +141 -0
  1261. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1262. mindspore/profiler/profiling.py +2054 -0
  1263. mindspore/rewrite/__init__.py +29 -0
  1264. mindspore/rewrite/api/__init__.py +17 -0
  1265. mindspore/rewrite/api/node.py +519 -0
  1266. mindspore/rewrite/api/node_type.py +53 -0
  1267. mindspore/rewrite/api/pattern_engine.py +490 -0
  1268. mindspore/rewrite/api/scoped_value.py +181 -0
  1269. mindspore/rewrite/api/symbol_tree.py +497 -0
  1270. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1271. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1272. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1273. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1274. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1275. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1276. mindspore/rewrite/common/__init__.py +19 -0
  1277. mindspore/rewrite/common/config.py +24 -0
  1278. mindspore/rewrite/common/error_log.py +39 -0
  1279. mindspore/rewrite/common/event.py +28 -0
  1280. mindspore/rewrite/common/namer.py +271 -0
  1281. mindspore/rewrite/common/namespace.py +118 -0
  1282. mindspore/rewrite/common/observable.py +44 -0
  1283. mindspore/rewrite/common/observer.py +54 -0
  1284. mindspore/rewrite/node/__init__.py +22 -0
  1285. mindspore/rewrite/node/call_function.py +95 -0
  1286. mindspore/rewrite/node/cell_container.py +139 -0
  1287. mindspore/rewrite/node/control_flow.py +113 -0
  1288. mindspore/rewrite/node/node.py +1428 -0
  1289. mindspore/rewrite/node/node_manager.py +283 -0
  1290. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1291. mindspore/rewrite/parsers/__init__.py +29 -0
  1292. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1293. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1294. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1295. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1296. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1297. mindspore/rewrite/parsers/container_parser.py +88 -0
  1298. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1299. mindspore/rewrite/parsers/for_parser.py +61 -0
  1300. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1301. mindspore/rewrite/parsers/if_parser.py +85 -0
  1302. mindspore/rewrite/parsers/module_parser.py +117 -0
  1303. mindspore/rewrite/parsers/parser.py +43 -0
  1304. mindspore/rewrite/parsers/parser_register.py +86 -0
  1305. mindspore/rewrite/parsers/return_parser.py +37 -0
  1306. mindspore/rewrite/parsers/while_parser.py +59 -0
  1307. mindspore/rewrite/sparsify/__init__.py +0 -0
  1308. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1309. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1310. mindspore/rewrite/sparsify/utils.py +179 -0
  1311. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1312. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1313. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1314. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1315. mindspore/run_check/__init__.py +20 -0
  1316. mindspore/run_check/_check_version.py +574 -0
  1317. mindspore/run_check/run_check.py +66 -0
  1318. mindspore/safeguard/__init__.py +18 -0
  1319. mindspore/safeguard/rewrite_obfuscation.py +531 -0
  1320. mindspore/swresample-4.dll +0 -0
  1321. mindspore/swscale-6.dll +0 -0
  1322. mindspore/tbbmalloc.dll +0 -0
  1323. mindspore/tinyxml2.dll +0 -0
  1324. mindspore/train/__init__.py +47 -0
  1325. mindspore/train/_utils.py +439 -0
  1326. mindspore/train/amp.py +817 -0
  1327. mindspore/train/anf_ir_pb2.py +1517 -0
  1328. mindspore/train/callback/__init__.py +44 -0
  1329. mindspore/train/callback/_backup_and_restore.py +117 -0
  1330. mindspore/train/callback/_callback.py +613 -0
  1331. mindspore/train/callback/_checkpoint.py +751 -0
  1332. mindspore/train/callback/_cluster_monitor.py +201 -0
  1333. mindspore/train/callback/_dataset_graph.py +150 -0
  1334. mindspore/train/callback/_early_stop.py +239 -0
  1335. mindspore/train/callback/_flops_collector.py +238 -0
  1336. mindspore/train/callback/_history.py +92 -0
  1337. mindspore/train/callback/_lambda_callback.py +80 -0
  1338. mindspore/train/callback/_landscape.py +1049 -0
  1339. mindspore/train/callback/_loss_monitor.py +107 -0
  1340. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1341. mindspore/train/callback/_mindio_ttp.py +443 -0
  1342. mindspore/train/callback/_on_request_exit.py +195 -0
  1343. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1344. mindspore/train/callback/_summary_collector.py +1184 -0
  1345. mindspore/train/callback/_time_monitor.py +141 -0
  1346. mindspore/train/checkpoint_pb2.py +233 -0
  1347. mindspore/train/data_sink.py +219 -0
  1348. mindspore/train/dataset_helper.py +688 -0
  1349. mindspore/train/lineage_pb2.py +1260 -0
  1350. mindspore/train/loss_scale_manager.py +213 -0
  1351. mindspore/train/memory_profiling_pb2.py +298 -0
  1352. mindspore/train/metrics/__init__.py +175 -0
  1353. mindspore/train/metrics/accuracy.py +133 -0
  1354. mindspore/train/metrics/auc.py +129 -0
  1355. mindspore/train/metrics/bleu_score.py +170 -0
  1356. mindspore/train/metrics/confusion_matrix.py +700 -0
  1357. mindspore/train/metrics/cosine_similarity.py +109 -0
  1358. mindspore/train/metrics/dice.py +116 -0
  1359. mindspore/train/metrics/error.py +175 -0
  1360. mindspore/train/metrics/fbeta.py +167 -0
  1361. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1362. mindspore/train/metrics/loss.py +97 -0
  1363. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1364. mindspore/train/metrics/metric.py +373 -0
  1365. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1366. mindspore/train/metrics/perplexity.py +133 -0
  1367. mindspore/train/metrics/precision.py +160 -0
  1368. mindspore/train/metrics/recall.py +159 -0
  1369. mindspore/train/metrics/roc.py +223 -0
  1370. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1371. mindspore/train/metrics/topk.py +167 -0
  1372. mindspore/train/mind_ir_pb2.py +1903 -0
  1373. mindspore/train/model.py +2176 -0
  1374. mindspore/train/node_strategy_pb2.py +653 -0
  1375. mindspore/train/print_pb2.py +184 -0
  1376. mindspore/train/profiling_parallel_pb2.py +151 -0
  1377. mindspore/train/serialization.py +3101 -0
  1378. mindspore/train/summary/__init__.py +23 -0
  1379. mindspore/train/summary/_lineage_adapter.py +41 -0
  1380. mindspore/train/summary/_summary_adapter.py +496 -0
  1381. mindspore/train/summary/_writer_pool.py +207 -0
  1382. mindspore/train/summary/enums.py +56 -0
  1383. mindspore/train/summary/summary_record.py +581 -0
  1384. mindspore/train/summary/writer.py +167 -0
  1385. mindspore/train/summary_pb2.py +1165 -0
  1386. mindspore/train/train_thor/__init__.py +20 -0
  1387. mindspore/train/train_thor/convert_utils.py +268 -0
  1388. mindspore/train/train_thor/dataset_helper.py +192 -0
  1389. mindspore/train/train_thor/model_thor.py +257 -0
  1390. mindspore/turbojpeg.dll +0 -0
  1391. mindspore/vcmeta.dll +0 -0
  1392. mindspore/vcomp140.dll +0 -0
  1393. mindspore/vcruntime140.dll +0 -0
  1394. mindspore/vcruntime140_1.dll +0 -0
  1395. mindspore/version.py +1 -0
  1396. mindspore-2.3.0.dist-info/METADATA +351 -0
  1397. mindspore-2.3.0.dist-info/RECORD +1400 -0
  1398. mindspore-2.3.0.dist-info/WHEEL +5 -0
  1399. mindspore-2.3.0.dist-info/entry_points.txt +4 -0
  1400. mindspore-2.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2567 @@
1
+ # Copyright 2020-2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """Inner operators."""
17
+ from types import FunctionType, MethodType
18
+ from collections.abc import Iterable
19
+ import os
20
+ import numpy as np
21
+
22
+ from mindspore.common import Tensor
23
+ from mindspore.common._stub_tensor import StubTensor
24
+ from mindspore.ops import composite as C
25
+ from mindspore.ops.operations.array_ops import Cast
26
+ from mindspore.ops.operations._scalar_ops import bit_or, bit_and
27
+ from mindspore.ops import signature as sig
28
+ from mindspore.ops.operations.math_ops import _infer_shape_reduce
29
+ from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
30
+ _run_op, _check_contains_variable
31
+ from mindspore._c_expression import Tensor as Tensor_
32
+ from mindspore._c_expression import typing
33
+ from mindspore import _checkparam as validator
34
+ from mindspore.common import dtype as mstype
35
+ from mindspore.common.parameter import Parameter
36
+ from mindspore.communication.management import GlobalComm, get_rank, _get_group, get_group_size
37
+ from mindspore.common.api import _pynative_executor
38
+ from mindspore.common._register_for_adapter import ms_adapter_registry
39
+ from mindspore import ops
40
+ from ..auto_generate import TensorCopySlices, SiLU, Cummin, TopKRouter, ExtractImagePatches, DecoderKVCache, \
41
+ PromptKVCache, ApplyCamePart1, ApplyCamePart2, ApplyCamePart3, ApplyCamePart4
42
+
43
# Bit operation primitives. ``bit_and`` and ``bit_or`` rebind the classes imported from
# ``_scalar_ops`` to callable instances; the others are generic primitives created by name.
bit_and, bit_or = bit_and(), bit_or()
bit_xor, bit_left_shift, bit_right_shift = (
    Primitive("bit_xor"),
    Primitive("bit_left_shift"),
    Primitive("bit_right_shift"),
)

# String operation primitives, each created by name.
string_lt, string_gt, string_le, string_ge = (
    Primitive("string_lt"),
    Primitive("string_gt"),
    Primitive("string_le"),
    Primitive("string_ge"),
)
string_not, string_in, string_mul, string_getitem = (
    Primitive("string_not"),
    Primitive("string_in"),
    Primitive("string_mul"),
    Primitive("string_getitem"),
)
58
+
59
+
60
class Generator(Primitive):
    r"""
    Manage the state of random number generation.

    Inputs:
        - **cmd** (int) : Identifier of the operation to execute. ``0`` is the step command, which is
          answered directly in Python by returning the first two operands.
        - **inputs** (tuple[Tensor]) : Operands for the operation.

    Outputs:
        - **seed** (Tensor): Seed for the random number generation algorithm.
        - **offset** (Tensor): Offset of the random number sequence.
        - **state** (Tensor): State tensor, can be used to restore current state.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Generator and flag it as having a memory side effect."""
        self.add_prim_attr("side_effect_mem", True)

    def __call__(self, cmd, inputs):
        is_step_cmd = cmd == 0
        if not is_step_cmd:
            # Every other command is dispatched to the underlying primitive.
            return super().__call__(cmd, inputs)
        # Step command: hand back the first two operands without launching the primitive.
        first, second = inputs[0], inputs[1]
        return first, second
82
+
83
+
84
class Quant(PrimitiveWithInfer):
    r"""
    Returns the quantized value of input_x.

    If `sqrt_mode` is False:

    .. math::
        y = round(scale * x + offset)

    If `sqrt_mode` is True:

    .. math::
        y = round(scale * x * scale + offset)

    Note:
        This operation only supports Atlas 200/300/500 inference product.

    Args:
        scale (float) : Specifies the scaling ratio.
        offset (float): Specifies the offset.
        sqrt_mode (bool) : Specifies whether `scale` is applied twice, as in the formula above. Default: ``False``.
        round_mode (str): Specifies the way to round. Must be one of ["Round", "Floor", "Ceil", "Trunc"].
            Default: "Round".

    Inputs:
        - **input_x** (Tensor) : Input tensor. Its data type must be mindspore.float16 or mindspore.float32.

    Outputs:
        - Tensor: The quantized output tensor of type mindspore.int8, with the same shape as `input_x`.

    Examples:
        >>> input_x = Tensor([100.0, 150.0], mstype.float32)
        >>> quant = ops.Quant(80.0, 0.0, False, "Round")
        >>> y = quant(input_x)
    """

    @prim_attr_register
    def __init__(self, scale, offset, sqrt_mode=False, round_mode="Round"):
        check_type = validator.check_value_type
        self.scale = check_type("scale", scale, [float], self.name)
        self.offset = check_type("offset", offset, [float], self.name)
        self.sqrt_mode = check_type("sqrt_mode", sqrt_mode, [bool], self.name)
        supported_round_modes = ["Round", "Floor", "Ceil", "Trunc"]
        self.round_mode = validator.check_string(round_mode, supported_round_modes, "round_mode", self.name)
        # The output dtype is fixed to int8 and stored as an attribute; infer_dtype reads it back.
        self.add_prim_attr("dst_type", mstype.int8)

    def infer_shape(self, x_shape):
        # Element-wise operation: the output keeps the input shape.
        return x_shape

    def infer_dtype(self, x_type):
        validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
        validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
        attr_dict = self.get_attr_dict()
        return attr_dict['dst_type']
136
+
137
+
138
class Lamb(PrimitiveWithInfer):
    r"""
    LAMB optimizer algorithm.

    The Lamb optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
    <https://arxiv.org/abs/1904.00962>`_.

    Inputs:
        - **var** (Tensor) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means
          any number of additional dimensions. The data type can be float16 or float32.
        - **m** (Tensor) - The 1st moment vector in the updating formula,
          the shape and data type should be the same as `var`.
        - **v** (Tensor) - The 2nd moment vector (mean square gradients) in the updating formula,
          the shape and data type should be the same as `var`.
        - **lr** (Union[float, Tensor]) - The learning rate :math:`l` in the updating formula.
          Its data type must be float32.
        - **beta1** (Union[float, Tensor]) - The exponential decay rate for the 1st moment estimations.
          Its data type must be float32. The paper suggested value is :math:`0.9`.
        - **beta2** (Union[float, Tensor]) - The exponential decay rate for the 2nd moment estimations.
          Its data type must be float32. The paper suggested value is :math:`0.999`.
        - **epsilon** (Union[float, Tensor]) - Term added to the denominator to improve numerical stability.
          Its data type must be float32.
        - **decay** (Union[float, Tensor]) - The weight decay value. Its data type must be float32.
        - **global_step** (Tensor) - Tensor to record current global step.
        - **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.

        `lr`, `beta1`, `beta2`, `epsilon` and `decay` must all be float32, independent of the
        data type of `var`.

    Outputs:
        Tensor, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Lamb."""
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
                    epsilon_shape, decay_shape, global_step_shape, gradient_shape):
        # Only the moment vectors and the gradient are shape-checked against `var`; the
        # hyper-parameters and global_step are accepted as-is.
        validator.check("var_shape", var_shape, "m_shape", m_shape, validator.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, validator.EQ, self.name)
        validator.check("var_shape", var_shape, "gradient_shape", gradient_shape, validator.EQ, self.name)
        return var_shape

    def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
                    epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype):
        # Tensor operands must share one dtype, float16 or float32.
        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": gradient_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        # Hyper-parameters may be scalars or tensors but must all be float32, even when `var` is float16.
        args = {"lr": lr_dtype, "decay": decay_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype,
                "epsilon": epsilon_dtype}
        validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
        return var_dtype
194
+
195
+
196
class Dequant(PrimitiveWithInfer):
    r"""
    Returns the dequantized value of input_x.
    This operation will do ReLU to the dequantized value if `relu_flag` is True.

    If `sqrt_mode` is False:

    .. math::
        y = x * deq\_scale

    If `sqrt_mode` is True:

    .. math::
        y = x * deq\_scale * deq\_scale

    Note:
        This operation only supports Atlas 200/300/500 inference product.

    Args:
        sqrt_mode (bool) : Specifies whether `deq_scale` is applied twice, as in the formula above.
            Default: ``False``.
        relu_flag (bool): Specifies whether to perform ReLU. Default: ``False``.
        dtype (mindspore.dtype): Stored on the primitive as ``dtype``. Default: ``mstype.float16``.
            Note that type inference always reports float16, regardless of this value.

    Inputs:
        - **input_x** (Tensor) : Input tensor. Must be mindspore.int32.
        - **deq_scale** (Tensor) : Specifies the scaling ratio.
          Data type must be mindspore.float16 or mindspore.uint64.

    Outputs:
        - Tensor: The dequantized output tensor of type mindspore.float16, with the same shape as `input_x`.

    Examples:
        >>> from mindspore.ops.operations._inner_ops import Dequant
        >>> input_x = Tensor([100, 150], mstype.int32)
        >>> deq_scale = Tensor([0.5], mstype.float16)
        >>> dequant = Dequant(False, False)
        >>> y = dequant(input_x, deq_scale)
    """

    @prim_attr_register
    def __init__(self, sqrt_mode=False, relu_flag=False, dtype=mstype.float16):
        """Initialize Dequant."""
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
        self.dtype = dtype

    def infer_shape(self, x_shape, deq_scale_shape):
        # Element-wise scaling: the output keeps the shape of `input_x`; deq_scale's shape is not checked here.
        return x_shape

    def infer_dtype(self, x_type, deq_scale_type):
        validator.check_subclass("x", x_type, mstype.tensor_type, self.name)
        validator.check_type_name("x", x_type, [mstype.int32], self.name)
        validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
        # NOTE(review): always float16; `self.dtype` is not consulted here.
        return mstype.float16
246
+
247
+
248
class AntiQuant(Primitive):
    r"""
    Returns the antiquantized value of input_x.

    If `sqrt_mode` is False:

    .. math::
        y = scale * (x + offset)

    If `sqrt_mode` is True:

    .. math::
        y = scale * scale * (x + offset)

    Note:
        This operation only supports Atlas 200/300/500 inference product.

    Args:
        sqrt_mode (bool) : Specifies whether `scale` is applied twice, as in the formula above.
            Default: ``False``.
        dtype (mindspore.dtype): Stored on the primitive as ``dtype``. Default: ``mstype.float16``.

    Inputs:
        - **input_x** (Tensor) : Input tensor. Must be mindspore.int8.
        - **scale** (Tensor) : Specifies the scaling ratio.
        - **offset** (Tensor) : Specifies the offset.

    Outputs:
        - Tensor: The antiquantized output tensor.

    Examples:
        >>> from mindspore.ops.operations._inner_ops import AntiQuant
        >>> input_x = Tensor([50, 20], mstype.int8)
        >>> scale = Tensor([2.0], mstype.float16)
        >>> offset = Tensor([1.0], mstype.float16)
        >>> antiquant = AntiQuant(False)
        >>> y = antiquant(input_x, scale, offset)
        >>> print(y)
        [102. 42.]
    """

    @prim_attr_register
    def __init__(self, sqrt_mode=False, dtype=mstype.float16):
        """Initialize AntiQuant."""
        # No explicit `super().__init__("AntiQuant")`: `prim_attr_register` already initializes the
        # Primitive base under the class name and records `sqrt_mode`/`dtype` as primitive attributes
        # before this body runs. Re-running `Primitive.__init__` here would reset those attributes.
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        self.dtype = dtype

        self.init_prim_io_names(inputs=['x', 'scale', 'offset'],
                                outputs=['y'])
293
+
294
+
295
class MatrixDiag(PrimitiveWithInfer):
    """
    Returns a batched diagonal tensor with a given batched diagonal values.

    Inputs:
        - **x** (Tensor) - A tensor which is multiplied element-wise by `assist`. It can be one of the following
          data types: float32, float16, int32, int8, and uint8. Its rank plus one must not exceed the rank of
          `assist`.
        - **assist** (Tensor) - An eye tensor of the same type as `x`. Its rank must be greater than or equal to 2
          and its last dimension must be equal to the second to last dimension.

    Outputs:
        Tensor, has the same type and shape as input `assist`.

    Examples:
        >>> x = Tensor(np.array([1, -1]), mstype.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag = ops.MatrixDiag()
        >>> result = matrix_diag(x, assist)
        >>> print(result)
        [[[-12.   11.]
          [-10.    9.]]
         [[ -8.    7.]
          [ -6.    5.]]
         [[ -4.    3.]
          [ -2.    1.]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixDiag"""

    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_int(len(assist_shape), 2, validator.GE, "assist rank", self.name)
        validator.check('rank of x', len(x_shape) + 1,
                        'rank of assist', len(assist_shape), validator.LE, self.name)
        validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
                        assist_shape[-1], validator.EQ, self.name)

        # Walk x's dims from the right (-1 down to -len(x_shape)). Every dim of x that is not 1 must equal
        # the assist dim one position further left, since assist carries one extra trailing dimension.
        for r_idx in range(-1, -len(x_shape) - 1, -1):
            if x_shape[r_idx] != 1:
                # Label with the assist dim *index* (r_idx - 1), matching the x label, not the dim's value.
                validator.check("reverse x dim %d" % r_idx, x_shape[r_idx],
                                "reverse assist dim %d" % (r_idx - 1), assist_shape[r_idx - 1],
                                validator.EQ, self.name)

        return assist_shape
348
+
349
+
350
class MatrixDiagPart(PrimitiveWithInfer):
    r"""
    Returns the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor, of rank at least 2. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **assist** (Tensor) - An eye tensor of the same type and shape as `x`.

    Outputs:
        Tensor, data type same as input `x`. The shape is x.shape[:-2] + [min(x.shape[-2:])].

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag_part = ops.MatrixDiagPart()
        >>> result = matrix_diag_part(x, assist)
        >>> print(result)
        [[12., -9.], [8., -5.], [4., -1.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixDiagPart"""

    def infer_dtype(self, x_dtype, assist_dtype):
        checked = {"x": x_dtype, "assist": assist_dtype}
        supported = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        validator.check_tensors_dtypes_same_and_valid(checked, supported, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_int(len(x_shape), 2, validator.GE, "x rank", self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, validator.EQ, self.name)

        rows, cols = assist_shape[-2], assist_shape[-1]
        # The diagonal length is min(rows, cols): keep the smaller trailing dim, drop the other.
        if rows >= cols:
            return assist_shape[:-2] + assist_shape[-1:]
        return assist_shape[:-1]
390
+
391
+
392
class MatrixSetDiag(PrimitiveWithInfer):
    r"""
    Modifies the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.
          Its shape must be `x`'s shape with the larger of the two trailing dimensions removed.
        - **assist** (Tensor) - An eye tensor of the same type as `x`. With shape same as `x`.

    Outputs:
        Tensor, data type same as input `x`. The shape same as `x`.

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
        >>> assist = Tensor(np.tile(np.eye(2), (3, 1, 1)), mindspore.float32)
        >>> matrix_set_diag = ops.MatrixSetDiag()
        >>> result = matrix_set_diag(x, diagonal, assist)
        >>> print(result)
        [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]]

    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixSetDiag"""

    def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, diagonal_shape, assist_shape):
        validator.check_int(len(x_shape), 2, validator.GE, "x rank", self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, validator.EQ, self.name)

        # The diagonal has min(rows, cols) entries, so its expected shape drops the larger trailing dim.
        if x_shape[-2] < x_shape[-1]:
            validator.check("diagonal shape", diagonal_shape, "x shape excluding the last dimension",
                            x_shape[:-1], validator.EQ, self.name)
        else:
            validator.check("diagonal shape", diagonal_shape, "x shape excluding the second last dimension",
                            x_shape[:-2] + x_shape[-1:], validator.EQ, self.name)

        return assist_shape
437
+
438
+
439
class ConfusionMulGrad(PrimitiveWithInfer):
    """
    `output0` is the dot product result of input0 and input1.

    `output1` is the dot product result of input0 and input1, then apply the reducesum operation on it.

    Args:
        axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
            Default: (), reduce all dimensions. Only constant value is allowed.
        keep_dims (bool):

            - If true, keep these reduced dimensions and the length as 1.
            - If false, don't keep these dimensions. Default: False.

    Inputs:
        - **input_0** (Tensor) - The input Tensor.
        - **input_1** (Tensor) - The input Tensor.
        - **input_2** (Tensor) - The input Tensor.

    Outputs:
        - **output_0** (Tensor) - The same shape as `input0`.
        - **output_1** (Tensor) - The shape of `input1` reduced over `axis` (honouring `keep_dims`):

          - If axis is (), and keep_dims is false, the output is a 0-D array representing
            the sum of all elements in the input array.
          - If axis is int, set as 2, and keep_dims is false,
            the shape of output is :math:`(x_1,x_3,...,x_R)`.
          - If axis is tuple(int), set as (2,3), and keep_dims is false,
            the shape of output is :math:`(x_1,x_4,...x_R)`.

    Examples:
        >>> confusion_mul_grad = ops.ConfusionMulGrad()
        >>> input_0 = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
        >>> input_1 = Tensor(np.random.randint(0, 4, (2, 3)), mindspore.float32)
        >>> input_2 = Tensor(np.random.randint(-4, 0, (2, 3)), mindspore.float32)
        >>> output_0, output_1 = confusion_mul_grad(input_0, input_1, input_2)
        >>> print(output_0.shape, output_1.shape)
        (2, 3) ()
    """

    @prim_attr_register
    def __init__(self, axis=(), keep_dims=False):
        self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
        self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
        self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)

    def infer_shape(self, input0_shape, input1_shape, input2_shape):
        """output0 keeps input0's shape; output1 is input1's shape reduced over the stored axis."""
        reduced_shape = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name)
        return input0_shape, reduced_shape

    def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype):
        """Every input must be a tensor; the outputs take the dtypes of input0 and input1."""
        named_dtypes = (
            ("input0_dtype", input0_dtype),
            ("input1_dtype", input1_dtype),
            ("input2_dtype", input2_dtype),
        )
        for arg_name, arg_dtype in named_dtypes:
            validator.check_subclass(arg_name, arg_dtype, mstype.tensor_type, self.name)
        return input0_dtype, input1_dtype
498
+
499
+
500
class ConvertToDynamic(PrimitiveWithCheck):
    """
    This op is used for dynamic rank testing. Its inferred shape will be unknown
    during compile time, so that its output will appear to be dynamically ranked.
    The input will not be altered in any way. Put this operator before the operator
    being tested for dynamic rank support.

    Args:
        is_dynamic_rank (bool): If true, convert to dynamic rank.
            If false, convert to dynamic shape. Default: ``False``.

    Inputs:
        - **input** (Tensor) - The tensor used for testing. Must have rank > 0.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> from mindspore.ops import operations as P
        >>> class TestDynamicNet(nn.Cell):
        ...     def __init__(self):
        ...         super(TestDynamicNet, self).__init__()
        ...         self.convert_to_dynamic = inner.ConvertToDynamic()
        ...         # suppose we are testing Reshape op
        ...         self.reshape = P.Reshape()
        ...
        ...     def construct(self, input, new_shape):
        ...         dynamic_input = self.convert_to_dynamic(input)
        ...         return self.reshape(dynamic_input, new_shape)
        ...
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")
        >>> input = Tensor(np.array([0, 1, 2, 3]))
        >>> new_shape = (2, 2)
        >>> net = TestDynamicNet()
        >>> output = net(input, new_shape)
        >>> print(output)
        [[0 1]
         [2 3]]
    """

    @prim_attr_register
    def __init__(self, is_dynamic_rank=False):
        validator.check_value_type('is_dynamic_rank', is_dynamic_rank, [bool], self.name)
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def check_shape(self, input_shape):
        """Reject scalar (rank-0) inputs."""
        validator.check("input_shape rank", len(input_shape), "", 0, validator.GT, self.name)

    def check_dtype(self, input_dtype):
        """The input must be a tensor."""
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
555
+
556
+
557
class GpuConvertToDynamicShape(PrimitiveWithCheck):
    """
    This op is used for dynamic shape testing. Its inferred shape will be unknown
    during compile time, so that its output will appear to be dynamically shaped.
    The input will not be altered in any way. Put this operator before the operator
    being tested for dynamic shape support.

    Inputs:
        - **input** (Tensor) - The tensor used for testing. Must have rank > 0.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> from mindspore.ops import operations as P
        >>> class TestDynamicShapeReshapeNet(nn.Cell):
        ...     def __init__(self):
        ...         super(TestDynamicShapeReshapeNet, self).__init__()
        ...         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        ...         # suppose we are testing Reshape op
        ...         self.reshape = P.Reshape()
        ...
        ...     def construct(self, input, new_shape):
        ...         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        ...         return self.reshape(dynamic_shape_input, new_shape)
        ...
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0, 1, 2, 3]))
        >>> new_shape = (2, 2)
        >>> net = TestDynamicShapeReshapeNet()
        >>> output = net(input, new_shape)
        >>> print(output)
        [[0 1]
         [2 3]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def check_shape(self, input_shape):
        """Reject scalar (rank-0) inputs."""
        validator.check("input_shape rank", len(input_shape), "", 0, validator.GT, self.name)

    def check_dtype(self, input_dtype):
        """The input must be a tensor."""
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
605
+
606
+
607
class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
    """
    This op is used for dynamic shape testing. The only purpose of this operator is
    that it will throw a value error if the input is dynamically shaped.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Raises:
        ValueError: If any dimension of `input` is unknown (negative), which covers both
            dynamic shape (``-1``) and dynamic rank (``-2``).

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> class AssertDynamicShapeNet(nn.Cell):
        ...     def __init__(self):
        ...         super(AssertDynamicShapeNet, self).__init__()
        ...         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        ...         self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()
        ...
        ...     def construct(self, input):
        ...         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        ...         return self.error_on_dynamic_shape_input(dynamic_shape_input)
        ...
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0]))
        >>> net = AssertDynamicShapeNet()
        >>> output = net(input)
        ValueError: Input is dynamically shaped.
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def infer_shape(self, input_shape):
        """Raise if any dimension is unknown; otherwise pass the shape through unchanged."""
        # Negative dims mark unknown sizes: -1 for a dynamic dim, -2 for a dynamic rank.
        # Checking only -1 would let dynamically ranked inputs through silently.
        if any(dim < 0 for dim in input_shape):
            raise ValueError("Input is dynamically shaped.")
        return input_shape

    def infer_dtype(self, input_dtype):
        """Infer the dtype of input for ErrorOnDynamicShapeInput."""
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
        return input_dtype

    # Backward-compatible alias: the hook was previously named `infer_type`, while every other
    # PrimitiveWithInfer subclass in this module implements `infer_dtype`.
    infer_type = infer_dtype

    def infer_value(self, input_tensor):
        """The value passes through unchanged."""
        return input_tensor
661
+
662
+
663
class SequenceMask(PrimitiveWithCheck):
    """
    Returns a mask tensor representing the first N positions of each cell.

    If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type bool and shape
    [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]).

    Inputs:
        - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor should be
          less than or equal to `maxlen`. Values greater than `maxlen` will be treated as `maxlen`.
          Must be type int32 or int64. Must have rank > 0.

        - **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same
          type as elements in `lengths`. Must be a scalar.

    Outputs:
        One mask tensor of shape lengths.shape + (maxlen,).

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> import numpy as np
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> sequence_mask = ops.SequenceMask()
        >>> output = sequence_mask(x, 3)
        >>> print(output)
        [[[True False False]
          [True True True]]
         [[True True False]
          [False False False]]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"])

    def check_shape(self, lengths_shape, maxlen_shape):
        """`lengths` must be at least 1-D and `maxlen` must be a scalar."""
        lengths_rank = len(lengths_shape)
        validator.check("lengths_shape", lengths_rank, "", 0, validator.GT, self.name)
        maxlen_rank = len(maxlen_shape)
        validator.check("maxlen_shape", maxlen_rank, "", 0, validator.EQ, self.name)

    def check_dtype(self, lengths_dtype, maxlen_dtype):
        """`lengths` must be a tensor and `maxlen` a number."""
        validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor_type, self.name)
        validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)
708
+
709
+
710
class SyncBatchNorm(Primitive):
    r"""
    Sync Batch Normalization for input data and updated parameters.

    Sync Batch Normalization is cross device synchronized Batch Normalization. Batch Normalization is
    widely used in convolutional neural networks. This operation applies Batch Normalization over input
    to avoid internal covariate shift as described in the paper `Batch Normalization: Accelerating
    Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
    It rescales and recenters the features using a mini-batch of data and the learned parameters which
    can be described in the following formula,

    .. math::
        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta

    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.

    Args:
        epsilon (float): A small value added for numerical stability, in the range (0, 1]. Default: 1e-5.
        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
            Momentum value must be in [0, 1]. Default: 0.1.
        group (str): The communication group to work on. Default: "sync_bn_group0".
        device_num (int): The number of devices in each group, at least 2. Default: 2.

    Inputs:
        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
        - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `mean`.

    Outputs:
        Tuple of 5 Tensor, the normalized inputs and the updated parameters.

        - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`.
        - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to nn.SyncBatchNorm for direct use.
        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
        >>> scale = Tensor(np.ones([2]), mindspore.float32)
        >>> bias = Tensor(np.ones([2]), mindspore.float32)
        >>> mean = Tensor(np.ones([2]), mindspore.float32)
        >>> variance = Tensor(np.ones([2]), mindspore.float32)
        >>> sync_batch_norm = ops._inner_ops.SyncBatchNorm()
        >>> output = sync_batch_norm(input_x, scale, bias, mean, variance)
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.00000000e+00, 1.00000000e+00],
         [ 1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[2], dtype=Float32, value=
         [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
         [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
         [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
         [ 1.00000000e+00, 1.00000000e+00]))
    """

    @prim_attr_register
    def __init__(self, epsilon=1e-5, momentum=0.1, group="sync_bn_group0", device_num=2):
        # Argument validation: epsilon in (0, 1], momentum in [0, 1], group a str, device_num >= 2.
        validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
        validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
        validator.check_isinstance("group", group, str)
        validator.check_int(device_num, 2, validator.GE, "device_num", self.name)

        input_names = ['x', 'scale', 'offset', 'mean', 'variance']
        output_names = ['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']
        self.init_prim_io_names(inputs=input_names, outputs=output_names)

        for attr_name, attr_value in (('side_effect_mem', True), ('format', 'NCHW')):
            self.add_prim_attr(attr_name, attr_value)
783
+
784
+
785
class Centralization(PrimitiveWithInfer):
    """
    Computes centralization. y = x - mean(x, axis).

    Note:
        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The data type must be float16 or float32.
        - **axis** (Union[int, Tuple(int), List(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
          Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)).

    Outputs:
        Tensor, has the same shape and dtype as the `input_x`.

    Raises:
        TypeError: If `axis` is not one of the following types: int, list, tuple.
        TypeError: If `axis` has non-Int elements.
        ValueError: If `axis` is not a constant value.
        ValueError: If `axis`, or any element of it, is outside [-rank(input_x), rank(input_x)).

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> mindspore.set_seed(1)
        >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
        >>> centralization = ops.Centralization()
        >>> output = centralization(input_x, -1)
        >>> print(output)
        [[ 1.1180509 -1.1180508]
         [ 0.2723984 -0.2723984]]
    """

    __mindspore_signature__ = (
        sig.make_sig('input_x'),
        sig.make_sig('axis', default=())
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Centralization"""
        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output'])

    def __infer__(self, input_x, axis):
        """Validate dtype and axis; the output keeps the input's shape and dtype."""
        x_shape = list(input_x['shape'])
        x_dtype = input_x['dtype']
        axis_v = axis['value']
        rank = len(x_shape)

        validator.check_tensors_dtypes_same_and_valid({'input_x': x_dtype}, [mstype.float16, mstype.float32],
                                                      self.name)

        if axis_v is None:
            raise ValueError(f"For {self.name}, axis must be const.")
        validator.check_value_type('axis', axis_v, [int, list, tuple], self.name)

        if isinstance(axis_v, int):
            validator.check_int_range(axis_v, -rank, rank, validator.INC_LEFT, 'axis', self.name)
        else:
            # Test the axis *value* (the original tested the abstract `axis` dict, which is always
            # truthy) and range-check every element as the docstring requires.
            for index, one_axis in enumerate(axis_v):
                validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name)
                validator.check_int_range(one_axis, -rank, rank, validator.INC_LEFT, 'axis[%d]' % index,
                                          self.name)

        out = {'shape': x_shape,
               'dtype': x_dtype,
               'value': None}
        return out
850
+
851
+
852
class StackInit(PrimitiveWithInfer):
    """
    Create a stack that produces tensors in first-in last-out order.

    After `StackInit`, a tensor can be pushed onto the stack using `StackPush`, and popped
    at the top of the stack using `StackPop`. Finally, the stack should be destroyed with `StackDestroy`.

    Args:
        index (int): The index of the stack. Default: 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> index = 0
        >>> stack = ops.StackInit(index)
        >>> push = ops.StackPush(index)
        >>> pop = ops.StackPop(index, x.shape, x.dtype)
        >>> destroy = ops.StackDestroy(index)
        >>> stack()
        >>> push(x)
        >>> y = pop()
        >>> destroy()
        >>> print(y)
        [[1 3]
         [2 0]]
    """

    @prim_attr_register
    def __init__(self, index=1):
        """Validate the stack index; `prim_attr_register` records it as an attribute."""
        validator.check_value_type("index", index, [int], self.name)
885
+
886
+
887
class StackPush(PrimitiveWithInfer):
    """
    Push a tensor onto the stack.

    Before `StackPush`, the stack should be created using `StackInit`.
    Please refer to the usage in source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.

    Inputs:
        - **input** (Tensor) - A tensor to be pushed onto the stack.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1):
        """Validate the stack index and declare a single input with no outputs."""
        validator.check_value_type("index", index, [int], self.name)
        no_outputs = []
        self.init_prim_io_names(inputs=['input'], outputs=no_outputs)
912
+
913
+
914
class StackPop(PrimitiveWithInfer):
    """
    Pop the tensor at the top of the stack.

    Before `StackPop`, the stack should be created using `StackInit`.
    Please refer to the usage in source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.
        shape (tuple): The shape of the tensor at the top of the stack. Must be a flat list/tuple of
            positive ints. Default: (1,).
        dtype (mindspore.dtype): The type of the tensor at the top of the stack, bool or a number type.
            Default: mindspore.float32.

    Outputs:
        - **output** (Tensor) - The tensor at the top of the stack.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1, shape=(1,), dtype=mstype.float32):
        """Validate index/shape/dtype and remember the shape and dtype for inference."""
        validator.check_value_type("index", index, [int], self.name)

        validator.check_value_type('shape type', shape, [list, tuple], self.name)
        # `shape` itself must be one-dimensional (no nested sequences).
        shape_rank = len(np.array(shape).shape)
        validator.check_int(shape_rank, 1, validator.EQ, "dim of shape", self.name)
        for dim in shape:
            validator.check_int(dim, 1, validator.GE, 'shape element', self.name)
            validator.check_value_type('type of shape element', dim, [int], self.name)

        allowed_dtypes = (mstype.bool_,) + mstype.number_type
        validator.check_type_name("dtype", dtype, allowed_dtypes, self.name)
        self.shape = shape
        self.dtype = dtype

        self.init_prim_io_names(inputs=[], outputs=['output'])

    def __infer__(self):
        """The popped tensor is described entirely by the constructor arguments."""
        out_shape = list(self.shape)
        return {'shape': out_shape, 'dtype': self.dtype, 'value': None}
957
+
958
+
959
class StackDestroy(PrimitiveWithInfer):
    """
    Destroy the stack.

    Before `StackDestroy`, the stack should be created using `StackInit`.
    Please refer to the usage in source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1):
        """Validate the index of the stack to destroy."""
        validator.check_value_type("index", index, [int], self.name)
980
+
981
+
982
class DynamicStitch(PrimitiveWithCheck):
    r"""
    Interleave the values from the data tensors into a single tensor.

    Inputs:
        - **indices** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
        - **data** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.

    Outputs:
        Tensor. A stacked Tensor with the same type as `data`.

    Raises:
        TypeError: If the data types of elements in `data` or `indices` are not the same.
        ValueError: If the length of `data` or `indices` is less than 1, if their lengths differ,
            or if the shape of ``data[i]`` does not start with the shape of ``indices[i]``.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x1 = Tensor([6], mstype.int32)
        >>> x2 = Tensor(np.array([4, 1]), mstype.int32)
        >>> x3 = Tensor(np.array([[5, 2], [0, 3]]), mstype.int32)
        >>> y1 = Tensor(np.array([[61, 62]]), mstype.int32)
        >>> y2 = Tensor(np.array([[41, 42], [11, 12]]), mstype.int32)
        >>> y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mstype.int32)
        >>> stitch = ops.DynamicStitch()
        >>> output = stitch([x1, x2, x3], [y1, y2, y3])
        >>> print(output)
        [[ 1  2]
         [11 12]
         [21 22]
         [31 32]
         [41 42]
         [51 52]
         [61 62]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DynamicStitch"""

    def check_shape(self, indices_shape, data_shape):
        """Validate that each data[i].shape == indices[i].shape + common_trailing_shape."""
        validator.check_value_type("shape of indices", indices_shape, [tuple, list], self.name)
        validator.check_int(len(indices_shape), 1, validator.GE, "len of indices_shape", self.name)
        indices_dim0 = len(indices_shape[0])
        indices_num = len(indices_shape)

        validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
        validator.check_int(len(data_shape), 1, validator.GE, "len of data_shape", self.name)
        data_dim0 = len(data_shape[0])
        # Must count `data`, not `indices`; otherwise the size check below compares a value to itself.
        data_num = len(data_shape)

        validator.check("size of indices", indices_num, 'size of data', data_num, validator.EQ, self.name)

        # shape of `data` must start with shape of `indices`
        for i in range(indices_num):
            indices_dim = len(indices_shape[i])
            data_dim = len(data_shape[i])
            validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, validator.LE, self.name)
            # Normalize to lists so tuple/list shape containers compare by value.
            if list(data_shape[i][:indices_dim]) != list(indices_shape[i]):
                raise ValueError(f"data[{i}].shape: {data_shape[i]} does not start with "
                                 f"indices[{i}].shape: {indices_shape[i]}")

        # the last-(data_dim0-indices_dim0)-dim of data shape must end with same shape.
        base_extra = data_dim0 - indices_dim0
        for i in range(data_num):
            indices_dim = len(indices_shape[i])
            data_dim = len(data_shape[i])
            extra = data_dim - indices_dim
            validator.check(f"extra dim of data[{i}]", extra,
                            f"extra dim of data[0]", base_extra, validator.EQ, self.name)
            validator.check(f"data[0].shape[{indices_dim0}:]", list(data_shape[0][indices_dim0:]),
                            f"data[{i}].shape[{indices_dim}:]",
                            list(data_shape[i][indices_dim:]), validator.EQ, self.name)

        # list() guards against `[-1] + tuple`, which would raise TypeError.
        out_shape = [-1] + list(data_shape[0][indices_dim0:])
        return out_shape

    def check_dtype(self, indices_type, data_type):
        """Indices must be int32 tensors; all data tensors must share one number/bool dtype."""
        validator.check_subclass("indices[0]", indices_type[0], mstype.tensor_type, self.name)
        validator.check_subclass("data[0]", data_type[0], mstype.tensor_type, self.name)
        indices_num = len(indices_type)
        for i in range(0, indices_num):
            validator.check_tensor_dtype_valid(f'indices[{i}]', indices_type[i], mstype.int32, self.name)
            validator.check_tensor_dtype_valid(f'data[{i}]', data_type[i],
                                               mstype.number_type + (mstype.bool_,), self.name)
            validator.check(f"type of data[{i}]", data_type[i], f"type of data[0]",
                            data_type[0], validator.EQ, self.name)
        return data_type[0]
1070
+
1071
+
1072
class DynamicBroadcastGradientArgs(Primitive):
    """
    Broadcast the two input shapes, return the dimensions that each need to be broadcast.

    Input shape `s0` and shape `s1` can be broadcast to a common shape if for each dimension pair they are either equal
    or input is one or the target dimension is -1. In case of -1 in target shape, it will be replaced by the input
    shape's value in that dimension.

    Inputs:
        - **s0** (Tensor) - A `1-D` tensor. The data type should be one of the following types: int32, int64,
          uint32, uint64.
        - **s1** (Tensor) - A `1-D` tensor with the same type as `s0`.

    Outputs:
        Tuple(Tensor), tuple of 2 tensors, r0 and r1. The first one is the index tensor and the other one is the mask
        tensor.

        - **r0** (Tensor) - The output shape is 1-D with the same type as s0.
        - **r1** (Tensor) - The output shape is 1-D with the same type as s0.

    Raises:
        ValueError: if the `s0` and `s1` are incompatible, or if a -1 in the target shape is in an invalid
            location.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> shape0 = (4, 2, 1)
        >>> shape1 = (2, 7)
        >>> from mindspore.ops.operations import _inner_ops
        >>> args = _inner_ops.DynamicBroadcastGradientArgs()
        >>> r0, r1 = args(Tensor(shape0), Tensor(shape1))
        >>> print(r0, r1)
        [2], [0]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DynamicBroadcastGradientArgs (no attributes to validate)."""
1112
+
1113
+
1114
class DSDMatmul(PrimitiveWithInfer):
    """
    The definition of the DSDMatmul custom primitive.

    Inputs are named `input_w1`, `input_w2` and `input_v`; the single output is `output_y`.
    Only shape and dtype inference are defined here. With

    - ``batch_size = input_w1.shape[0]``
    - ``head = input_w1.shape[1]``
    - ``v_embedding = input_v.shape[1] * 16 // head``
    - ``seq_len = input_v.shape[0] * 16 // batch_size``

    the output shape is ``(batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)`` and the output
    dtype is the dtype of `input_w1`.
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['input_w1', 'input_w2', 'input_v'], outputs=['output_y'])

    def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape):
        """Derive the 6-D output shape from `input_w1` and `input_v`; `input_w2` is not consulted."""
        batch_size = input_w1_shape[0]
        head = input_w1_shape[1]
        # NOTE(review): the factor 16 presumably reflects a 16x16 block (fractal) layout — confirm.
        v_embedding = input_v_shape[1] * 16 // head
        seq_len = input_v_shape[0] * 16 // batch_size
        return (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)

    def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3):
        """The output takes the dtype of `input_w1`."""
        return data_dtype1
1132
+
1133
+
1134
class MatmulDDS(PrimitiveWithInfer):
    """
    The definition of the MatmulDDS custom primitive.

    Inputs are named `q`, `k`, `local_mask` and `global_mask`; outputs are `local_prob` and
    `global_prob`. Only shape and dtype inference are defined here: both outputs take the dtype
    of `q`, and their 7-D shapes are derived from `q` and `local_mask` (see `infer_shape`).

    Args:
        bs (int): Recorded as a primitive attribute; not used by shape inference.
        heads (int): Number of heads, used to derive the per-head size.
    """

    @prim_attr_register
    def __init__(self, bs, heads):
        """init MatmulDDS"""
        self.init_prim_io_names(inputs=['q', 'k', 'local_mask', 'global_mask'],
                                outputs=['local_prob', 'global_prob'])
        self.heads = heads

    def infer_shape(self, q, k, local_mask, global_mask):
        """Return (local_prob shape, global_prob shape); `k` and `global_mask` are not consulted."""
        sequence_length = local_mask[0] * local_mask[-1]
        batch = q[1] * q[2] // sequence_length
        global_len = sequence_length // 4
        per_head = q[0] * q[-1] // self.heads
        num_heads = q[0] * q[-1] // per_head
        block_len = local_mask[1] * local_mask[2] // batch
        num_blocks = sequence_length // block_len

        # Both outputs share the leading (batch, heads, blocks) dims and the trailing 16x16 tile.
        local_shape = (batch, num_heads, num_blocks, block_len // 16, block_len // 16, 16, 16)
        global_shape = (batch, num_heads, num_blocks, global_len // 16, block_len // 16, 16, 16)
        return local_shape, global_shape

    def infer_dtype(self, q, k, local_mask, global_mask):
        """Both outputs take the dtype of `q`."""
        return q, q
1160
+
1161
+
1162
class DSDGrad(PrimitiveWithInfer):
    """
    The definition of the DSDGrad custom primitive.

    Presumably the gradient counterpart of `DSDMatmul` (inferred from the names — confirm).
    Inputs are named `w1_gm`, `w2_gm`, `v_gm`, `a_gm` and `d_a_gm`; outputs are `d_w1_gm`,
    `d_w2_gm` and `d_v_gm`. Each output has the shape of the corresponding `w1_gm`, `w2_gm`,
    `v_gm` input, and all outputs take the dtype of `w1_gm`.
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['w1_gm', 'w2_gm', 'v_gm', 'a_gm', 'd_a_gm'],
                                outputs=['d_w1_gm', 'd_w2_gm', 'd_v_gm'])

    def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape, input_a_shape, input_da_shape):
        """Each gradient output mirrors the shape of its forward input."""
        return input_w1_shape, input_w2_shape, input_v_shape

    def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3, data_dtype4, data_dtype5):
        """All three outputs take the dtype of the first input."""
        return data_dtype1, data_dtype1, data_dtype1
1177
+
1178
+
1179
class MatmulDDSGrad(PrimitiveWithInfer):
    """
    The definition of the MatmulDDSGrad custom primitive.

    Presumably the gradient counterpart of `MatmulDDS` (inferred from the names — confirm).
    Inputs are named `q`, `k`, `local_prob`, `global_prob`, `local_prob_grad` and
    `global_prob_grad`; outputs are `dq` and `dk`. `dq` has the shape and dtype of `q`; `dk` has
    the dtype of `k` and the shape ``(q[1], q[0], q[3], q[2])``.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatmulDDSGrad"""
        self.init_prim_io_names(inputs=['q', 'k', 'local_prob', 'global_prob', 'local_prob_grad', 'global_prob_grad'],
                                outputs=['dq', 'dk'])

    def infer_shape(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad):
        """Return (dq shape, dk shape)."""
        # NOTE(review): dk's shape is derived by permuting q's dims rather than copying k's shape;
        # presumably k is stored with the first two and last two axes swapped relative to q — confirm.
        k_size = (q[1], q[0], q[3], q[2])

        return q, k_size

    def infer_dtype(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad):
        """dq takes q's dtype and dk takes k's dtype."""
        return q, k
1195
+
1196
+
1197
class NonZeroWithValue(Primitive):
    """
    Returns the value of elements that are non-zero (in row-major order - by dimension).

    Args:
        transpose (bool): Recorded as a primitive attribute. Default: False.
            NOTE(review): presumably selects the layout of the `index` output — confirm.

    Inputs:
        - **x** (Tensor), input array of rank >= 2.

    Outputs:
        Three tensors named `value`, `index` and `count`, describing the non-zero elements.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> op = NonZeroWithValue()
        >>> data = Tensor(np.array([[1, 0, 0], [0, 0, 1]]), mindspore.float32)
        >>> value, index, count = op(data)
        >>> print(value)
        [1.0, 1.0]
    """

    @prim_attr_register
    def __init__(self, transpose=False):
        """Validate `transpose` and declare one input with three outputs."""
        validator.check_value_type("transpose", transpose, [bool], self.name)
        output_names = ['value', 'index', 'count']
        self.init_prim_io_names(inputs=['x'], outputs=output_names)
1222
+ self.init_prim_io_names(inputs=['x'], outputs=['value', 'index', 'count'])
1223
+
1224
+
1225
class NonZeroWithValueShape(Primitive):
    """
    Returns the value and index of elements that are non-zero (in row-major order - by dimension).

    Consumes the three outputs of `NonZeroWithValue`.

    Inputs:
        - **value** (Tensor) - The ``value`` output of `NonZeroWithValue`.
        - **index** (Tensor) - The ``index`` output of `NonZeroWithValue`.
        - **count** (Tensor) - The ``count`` output of `NonZeroWithValue`.

    Outputs:
        Two outputs, registered as ``out_value`` and ``out_index``: the values and the indices of the
        non-zero elements.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> non_zero = NonZeroWithValue()
        >>> op = NonZeroWithValueShape()
        >>> data = Tensor(np.array([[1, 0, 0], [0, 0, 1]]), mindspore.float32)
        >>> value, index, count = non_zero(data)
        >>> out_value, out_index = op(value, index, count)
        >>> print(out_index)
        [[0, 1], [0, 2]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize NonZeroWithValueShape"""
        self.init_prim_io_names(inputs=['value', 'index', 'count'], outputs=['out_value', 'out_index'])
1252
+
1253
+
1254
class DecodeImage(PrimitiveWithInfer):
    """
    Decodes image data from a string Tensor.

    Args:
        channels (int): Default: ``0``.
        dtype (mindspore.dtype): Output data type; returned unchanged by dtype inference.
            Default: ``mstype.uint8``.
        expand_animations (bool): Default: ``False``.
        _op_max_shape (str): Internal. Default: ``"8192,8192,3"``.
        _op_max_size (list[int]): Internal. Default: ``[8000000]``.

    Inputs:
        - **x** (Tensor), a 0-D Tensor of type string holding a JPEG, GIF, PNG or BMP-encoded image.
          It is registered under the input name ``contents``.

    Outputs:
        A Tensor of type uint8, uint16 or float. Its inferred shape is always ``(-1, -1, 3)``, regardless of
        `channels`.

    Supported Platforms:
        ``Ascend``
    """

    @prim_attr_register
    def __init__(self, channels=0, dtype=mstype.uint8, expand_animations=False, _op_max_shape="8192,8192,3",
                 _op_max_size=[8000000]):
        # NOTE(review): `_op_max_size` has a mutable list default shared across instances. It is not mutated
        # here. It is kept as a list because the decorator presumably records the argument as a primitive
        # attribute, and a tuple/None default would change that recorded value.
        self.init_prim_io_names(inputs=["contents"], outputs=["image"])
        self.res_type = dtype

    def infer_shape(self, x):
        # Height and width are unknown until decoding; the channel count is fixed at 3.
        return (-1, -1, 3)

    def infer_dtype(self, x):
        return self.res_type
1281
+
1282
+
1283
class SliceGetItem(Primitive):
    """
    Uses SliceGetItem to get a slice's ``start``, ``stop`` or ``step`` attribute.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SliceGetItem"""
        self.init_prim_io_names(inputs=['slice', 'attr'], outputs=['slice_item'])

    def __call__(self, slice_value, value):
        """
        Return the field named by `value` ("start", "stop" or "step") of `slice_value`.

        A field that has ``ndim == 1`` (e.g. a one-dimensional tensor) is unwrapped with ``.item()``;
        any other field is returned unchanged.

        Raises:
            TypeError: If `slice_value` is not a slice.
            AttributeError: If `value` is not one of "start", "stop" or "step".
        """
        if not isinstance(slice_value, slice):
            raise TypeError(
                "Primitive[SliceGetItem] only support to get a slice type element but got {}".format(slice_value))
        # The three fields share identical handling; compare in the same order as the original branches.
        for attr_name in ("start", "stop", "step"):
            if value == attr_name:
                item = getattr(slice_value, attr_name)
                if hasattr(item, "ndim") and item.ndim == 1:
                    return item.item()
                return item
        raise AttributeError("\'slice\' object has no attribute {}".format(value))
1310
+
1311
+
1312
class DynamicBroadcastTo(Primitive):
    """
    Broadcast an input tensor to a target shape that is itself given as a tensor input.

    Inputs:
        - **input_x** (Tensor) - Tensor to broadcast. Supported data types: float16, float32, int32, int8, uint8.
          Its shape is :math:`(N,*)`, where :math:`*` stands for any number of extra dimensions.
        - **shape** (Tensor): Target shape of the broadcast.

    Outputs:
        Tensor of the given `shape`, with the same data type as `input_x`.

    Raises:
        ValueError: if the target and input shapes are incompatible.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Register the input/output names of DynamicBroadcastTo."""
        io_inputs = ['x', 'shape']
        io_outputs = ['y']
        self.init_prim_io_names(inputs=io_inputs, outputs=io_outputs)
1336
+
1337
+
1338
class DynamicResizeNearestNeighbor(Primitive):
    r"""
    Resizes the input tensor by using the nearest neighbor algorithm.

    Resizes the input tensor to a given size by using the nearest neighbor algorithm. The nearest
    neighbor algorithm selects the value of the nearest point and does not consider the
    values of neighboring points at all, yielding a piecewise-constant interpolant.

    Note:
        The operator supports dynamic shape.

    Args:
        align_corners (bool): Whether the centers of the 4 corner pixels of the input
            and output tensors are aligned. Default: ``False``.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
        - **size** (Union[tuple, list]): The target size. The dimension of size must be 2.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`.
        The data type is the same as the `input_x`.
    """

    @prim_attr_register
    def __init__(self, align_corners=False):
        """Initialize DynamicResizeNearestNeighbor"""
        validator.check_value_type("align_corners", align_corners, [bool], self.name)
        # NOTE(review): only one input name ('image_in') is registered, although the docstring lists both
        # `input_x` and `size`. Left unchanged because backend kernels may depend on these names.
        self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
1367
+
1368
+
1369
class PsROIPooling(PrimitiveWithInfer):
    r"""
    Position Sensitive ROI-Pooling.

    Args:
        pooled_height (int): Height of the pooled output; the third dimension of both outputs.
        pooled_width (int): Width of the pooled output; the fourth dimension of both outputs.
        num_rois (int): Number of RoIs; the first dimension of both outputs.
        spatial_scale (float): Presumably the scale applied to RoI coordinates to map them onto `features`
            (e.g. ``1.0/16``).
        out_dim (int): Number of output channels; the second dimension of both outputs.
        group_size (int): Group size.

    Inputs:
        - **features** (Tensor) - The input features, whose shape must be :math:`(N, C, H, W)`.
        - **rois** (Tensor) - The shape is :math:`(rois\_n, 5)`. With data type of float16 or float32.
          `rois_n` represents the number of RoI. The size of the second dimension must be `5` and the `5` columns
          are :math:`(image\_index, top\_left\_x, top\_left\_y, bottom\_right\_x, bottom\_right\_y)`.
          `image_index` represents the index of image. `top_left_x` and `top_left_y` represent the `x, y`
          coordinates of the top left corner of corresponding RoI, respectively. `bottom_right_x` and `bottom_right_y`
          represent the `x, y` coordinates of the bottom right corner of corresponding RoI, respectively.

    Outputs:
        - **out** (Tensor) - The result after pooling, of shape
          :math:`(num\_rois, out\_dim, pooled\_height, pooled\_width)` and with the dtype of `features`.
        - **channel_map** (Tensor) - An int32 Tensor of the same shape, used by the backward pass to compute
          the gradient.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> features = np.random.randn(4, 21 * 7 * 7, 80, 48)
        >>> features = Tensor.from_numpy(features).astype(mindspore.float32)
        >>> rois = Tensor.from_numpy(
        ...     np.array([
        ...         [0.0000, 150.3563, 200.1320, 579.3563, 602.3452],
        ...         [1.0000, 657.1263, 302.8564, 762.4214, 567.9854],
        ...         [2.0000, 321.3122, 232.2410, 679.0281, 587.6346],
        ...         [3.0000, 664.1630, 387.4919, 778.7322, 562.7321],
        ...     ])).astype(mindspore.float32)
        >>> psRoIPooling = inner.PsROIPooling(pooled_height=7, pooled_width=7, num_rois=4,
        ...                                   spatial_scale=1.0/16, out_dim=21,
        ...                                   group_size=7)
        >>> out, channel_map = psRoIPooling(features, rois)
        >>> print(out.shape)
        (4, 21, 7, 7)
        >>> print(channel_map.shape)
        (4, 21, 7, 7)
    """

    @prim_attr_register
    def __init__(self, pooled_height, pooled_width, num_rois, spatial_scale, out_dim, group_size):
        """Initialize PsROIPooling"""
        validator.check_value_type("pooled_height", pooled_height, [int], self.name)
        validator.check_value_type("pooled_width", pooled_width, [int], self.name)
        # Validate num_rois itself (previously pooled_width was checked twice and num_rois never).
        validator.check_value_type("num_rois", num_rois, [int], self.name)
        validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
        validator.check_value_type("out_dim", out_dim, [int], self.name)
        validator.check_value_type("group_size", group_size, [int], self.name)
        self.pooled_height = pooled_height
        self.pooled_width = pooled_width
        self.num_rois = num_rois
        self.spatial_scale = spatial_scale
        self.out_dim = out_dim
        self.group_size = group_size

    def infer_shape(self, inputs_shape, rois_shape):
        # Both outputs share one shape, built entirely from the constructor attributes.
        output_shape = [self.num_rois, self.out_dim, self.pooled_height, self.pooled_width]
        output_map_shape = [self.num_rois, self.out_dim, self.pooled_height, self.pooled_width]
        return output_shape, output_map_shape

    def infer_dtype(self, inputs_type, rois_type):
        # The pooled output follows the features dtype; the channel map is always int32.
        map_type = mstype.TensorType(mstype.int32)
        return inputs_type, map_type
1438
+
1439
+
1440
class ParallelResizeBilinear(PrimitiveWithInfer):
    """
    ParallelResizeBilinear ops.

    Validates and records the resize configuration as primitive attributes. The output shape keeps the first
    two dimensions of `x` and takes the last two from `split_size`.
    """

    @prim_attr_register
    def __init__(self, ori_image_size, split_size, src_start_w, dst_start_w, align_corners):
        """Initialize ParallelResizeBilinear."""
        op_name = self.name
        validator.check_value_type("ori_image_size", ori_image_size, [list, tuple], op_name)
        validator.check_value_type("split_size", split_size, [list, tuple], op_name)
        validator.check_int(len(split_size), 2, validator.EQ, "len of split_size", op_name)
        validator.check_value_type("src_start_w", src_start_w, [int], op_name)
        validator.check_value_type("dst_start_w", dst_start_w, [int], op_name)
        validator.check_value_type("align_corners", align_corners, [bool], op_name)
        self.ori_image_size = list(ori_image_size)
        self.split_size = list(split_size)
        self.src_start_w = src_start_w
        self.dst_start_w = dst_start_w
        self.align_corners = align_corners
        # Half-pixel centers are never enabled for this operator.
        self.half_pixel_centers = False
        # Mirror every configuration field into the primitive attributes, in a fixed order.
        for attr_name in ('ori_image_size', 'split_size', 'src_start_w',
                          'dst_start_w', 'align_corners', 'half_pixel_centers'):
            self.add_prim_attr(attr_name, getattr(self, attr_name))

    def __infer__(self, x, size):
        size_const = size['value']
        in_shape = x['shape']
        in_dtype = x['dtype']
        validator.check_tensor_dtype_valid("x_dtype", in_dtype, [mstype.float16, mstype.float32], self.name)
        if size_const is None:
            raise ValueError("size must be const input")
        out_shape = [in_shape[0], in_shape[1], self.split_size[0], self.split_size[1]]
        return {'shape': out_shape, 'dtype': in_dtype, 'value': None}
1477
+
1478
+
1479
class PartitionedCall(PrimitiveWithInfer):
    """
    Pass the input tensors to the subgraph and return the output tensors.

    Args:
        graph: The subgraph to call; stored as ``self.graph``.
        executor_type (str): Recorded as the ``executor_type`` primitive attribute. Default: ``""``.

    Inputs:
        - **inputs** (Tuple), the input tensors, which will be passed to subgraph.

    Outputs:
        - outputs(Tuple), the output tensor returned by subgraph.

    Supported Platforms:
        ``Ascend``
    """

    @prim_attr_register
    def __init__(self, graph, executor_type=""):
        super(PartitionedCall, self).__init__(self.__class__.__name__)
        self.add_prim_attr("executor_type", executor_type)
        self.graph = graph

    def infer_shape(self, *inputs):
        # Python-side inference is not implemented. Raise instead of returning the exception class,
        # which a caller would otherwise silently treat as a shape.
        raise NotImplementedError("PartitionedCall does not implement Python-side shape inference.")

    def infer_dtype(self, *inputs):
        raise NotImplementedError("PartitionedCall does not implement Python-side dtype inference.")
1506
+
1507
+
1508
class CellBackwardHook(PrimitiveWithInfer):
    r"""
    This operator is used to hook input gradient and output gradient of Cell object.

    Note:
        This operator is only used in backward hook function of Cell object in pynative mode.

    Args:
        cell_id (str): Used to identify which cell obj the hook function registered on. For example, 'nn.Add()' is a
            cell object.

    Inputs:
        - **input** - The variable to hook.

    Outputs:
        - **output** - Returns `input` directly. `CellBackwardHook` does not affect the forward result.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import Tensor
        >>> from mindspore.ops import GradOperation
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> ms.set_context(mode=ms.PYNATIVE_MODE)
        >>> def hook_fn(grad):
        ...     print(grad)
        ...
        >>> hook = inner.CellBackwardHook()
        >>> hook_fn_key = hook.register_backward_hook(hook_fn)
        >>> def hook_test(x, y):
        ...     z = x * y
        ...     z = hook(z)
        ...     z = z * y
        ...     return z
        ...
        >>> grad_all = GradOperation(get_all=True)
        >>> def backward(x, y):
        ...     return grad_all(hook_test)(x, y)
        ...
        >>> output = backward(Tensor(1, ms.float32), Tensor(2, ms.float32))
        (Tensor(shape=[], dtype=Float32, value= 2),)
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
        >>> hook.remove_backward_hook(hook_fn_key)
        >>> output = backward(Tensor(1, ms.float32), Tensor(2, ms.float32))
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
    """

    def __init__(self, cell_id=""):
        """Initialize CellBackwardHook"""
        super(CellBackwardHook, self).__init__(self.__class__.__name__)
        self.cell_id = cell_id
        self.add_prim_attr("cell_id", cell_id)
        self.init_attrs["cell_id"] = cell_id

    def __call__(self, args):
        # A single non-tuple input is wrapped so `_run_op` always receives a tuple of inputs.
        if not isinstance(args, tuple):
            args = (args,)
        return _run_op(self, self.name, args)

    def infer_shape(self, *inputs_shape):
        # One input: its shape is returned directly; several inputs: the tuple of all shapes.
        if len(inputs_shape) == 1:
            return inputs_shape[0]
        return inputs_shape

    def infer_dtype(self, *inputs_type):
        # Mirrors infer_shape: unwrap a single dtype, otherwise return the tuple of dtypes.
        if len(inputs_type) == 1:
            return inputs_type[0]
        return inputs_type

    def register_backward_hook(self, hook_fn):
        r"""
        This function is used to register backward hook function. Note that this function is only supported in pynative
        mode.

        Note:
            The 'hook_fn' must be defined as the following code.
            `cell_id` is the information of registered cell. `grad_input` is the gradient passed to the cell.
            `grad_output` is the gradient computed and passed to the next cell or primitive, which may be modified by
            returning a new output gradient.
            The 'hook_fn' should have the following signature:
            hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
            The 'hook_fn' is executed in the python environment.

        Args:
            hook_fn (Function): Python function. Backward hook function.

        Returns:
            - **key** (int) - The key of 'hook_fn'.

        Raises:
            TypeError: If the `hook_fn` is not a function of python.
        """
        if not isinstance(hook_fn, (FunctionType, MethodType)):
            raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' must be python "
                            f"function, but got {type(hook_fn)}.")
        key = self.add_backward_hook_fn(hook_fn)
        return key

    def remove_backward_hook(self, key):
        r"""
        This function is used to remove backward hook function. Note that this operation is only supported in pynative
        mode.

        Note:
            The 'key' is the object returned by 'register_backward_hook' function of the same CellBackwardHook
            operator.

        Args:
            key (int): The key corresponding to the 'hook_fn'.

        Returns:
            None.
        """
        self.remove_backward_hook_fn(key)
1626
+
1627
+
1628
class Format(PrimitiveWithInfer):
    r"""
    This operator is used to format a string.

    Note:
        Currently not intended to be used directly by users.
        Only ``str.format()`` calls in user code are supported; the compiler (ME-Compiler) converts them
        into the Format operation automatically.

    Inputs:
        - **input** -
          string : the string to be formatted.
          args : the format args; keyword arguments are passed as keywords only.

    Outputs:
        - **output** - Returns formatted string.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])

    def __infer__(self, str_, *var):
        def check_variable(str_, var):
            if _check_contains_variable(str_['dtype'], str_['value']):
                return True

            for item in var:
                if _check_contains_variable(item['dtype'], item['value']):
                    return True
            return False

        # If `_check_contains_variable` flags the template or any argument (presumably: its value is not
        # known at compile time), only the output type can be inferred.
        if check_variable(str_, var):
            return {'dtype': mstype.string, 'shape': [], 'value': None}

        str_value = str_['value']
        kwargs = {}
        var_value = []

        for item in var:
            if isinstance(item["dtype"], typing.Keyword):
                # Keyword arguments go to **kwargs only; they must not also be consumed by a positional
                # `{}` replacement field.
                kwargs.update(item["value"])
            else:
                var_value.append(item["value"])

        value = str_value.format(*var_value, **kwargs)
        return {'dtype': mstype.string, 'shape': [], 'value': value}
1678
+
1679
+
1680
class FlattenConcat(Primitive):
    """
    Flatten input tensors and concatenate them into several chunk tensors grouped by data types.

    Args:
        fusion_size (int): Maximum memory chunk size in bytes, 0 for unlimited. Must be a non-negative int.
            Default: ``0``.

    Inputs:
        - **tensors** (tuple[Tensor], list[Tensor]) - The input Tensors to be flattened and concatenated.

    Outputs:
        tuple[Tensor], result chunk tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> t1 = Tensor(np.array([1]).astype(np.float32))
        >>> t2 = Tensor(np.array([2]).astype(np.float32))
        >>> t3 = Tensor(np.array([3]).astype(np.float64))
        >>> t4 = Tensor(np.array([4]).astype(np.float32))
        >>> t5 = Tensor(np.array([5]).astype(np.float64))
        >>> chunks = inner.FlattenConcat()([t1, t2, t3, t4, t5])
        >>> print(chunks[0].asnumpy())
        [1. 2. 4.]
        >>> print(chunks[1].asnumpy())
        [3. 5.]
    """

    @prim_attr_register
    def __init__(self, fusion_size=0):
        """Initialize FlattenConcat"""
        validator.check_non_negative_int(fusion_size, 'fusion_size', self.name)
        self.fusion_size = fusion_size
        self.add_prim_attr('fusion_size', fusion_size)
1716
+
1717
+
1718
class KMeansCentroids(PrimitiveWithInfer):
    """
    Calculate the segment_sum, segment_count, kmean_total_sum that are clustering results.

    Args:
        use_actual_distance (bool): A bool value to decide whether to do the complete calculation of distance.

    Inputs:
        - **x** (Tensor(float32)) - Input data used for clustering, of shape :math:`(n, d)`.
        - **y** (Tensor(float32)) - Initial centroids of clustering, of shape :math:`(k, d)`.
        - **sum_square_y** (Tensor(float32)) - The result of preprocessing such as square, reduce and transpose
          of y, of shape :math:`(1, k)`.
        - **sum_square_x** (Tensor(float32)) - The result of preprocessing such as square and reduce of x,
          of shape :math:`(n, 1)`.

    Outputs:
        - **segment_sum** (Tensor(float32)) - Clustering result w.r.t. each centroid, of shape :math:`(k, d)`.
        - **segment_count** (Tensor(float32)) - Clustering count w.r.t. each centroid, of shape :math:`(k, 1)`.
        - **kmean_total_sum** (Tensor(float32)) - The sum of the distances from all vectors to their nearest
          centroid, of shape :math:`(1,)`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.common.dtype as mstype
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> from mindspore.ops import operations as P
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.reduce_sum = P.ReduceSum(keep_dims=True)
        ...         self.square = P.Square()
        ...         self.transpose = P.Transpose()
        ...         self.k_means_centroids = P.KMeansCentroids(True)
        ...
        ...     def construct(self, x, y):
        ...         p1 = self.reduce_sum(self.square(x), -1)
        ...         p2 = self.transpose(self.reduce_sum(self.square(y), -1), (1, 0))
        ...         return self.k_means_centroids(x, y, p2, p1)
        ...
        >>> def test_net():
        ...     data_type = np.float32
        ...     x = Tensor(np.random.uniform(-10, 10, (65536, 128)).astype(data_type))
        ...     y = P.Ones()((1048576, 128), mstype.float32)
        ...     net = Net()
        ...     local_sum, local_count, local_avg_distance = net(x, y)
    """

    @prim_attr_register
    def __init__(self, use_actual_distance):
        validator.check_value_type('use_actual_distance', use_actual_distance, [bool], self.name)
        self.init_prim_io_names(inputs=['x', 'y', 'sum_square_y', 'sum_square_x'],
                                outputs=['segment_sum', 'segment_count', 'kmean_total_sum'])

    def infer_shape(self, x_shape, y_shape, sum_square_y_shape, sum_square_x_shape):
        """infer shape of primitive"""
        expected_shape_size = 2
        validator.check_int(len(x_shape), expected_shape_size, validator.EQ, "dims of x", self.name)
        validator.check_int(len(y_shape), expected_shape_size, validator.EQ, "dims of y", self.name)
        validator.check_int(len(sum_square_y_shape), expected_shape_size, validator.EQ,
                            "dims of sum_square_y", self.name)
        validator.check_int(len(sum_square_x_shape), expected_shape_size, validator.EQ,
                            "dims of sum_square_x", self.name)

        validator.check_int(x_shape[1], y_shape[1], validator.EQ,
                            "the second dim of x and the second dim of y", self.name)
        validator.check_int(y_shape[0], sum_square_y_shape[1], validator.EQ,
                            "the first dim of y and the second dim of sum_square_y", self.name)
        validator.check_int(x_shape[0], sum_square_x_shape[0], validator.EQ,
                            "the first dim of x and the first dim of sum_square_x", self.name)
        # Compares sum_square_y[0] with sum_square_x[1]; the message names the dims actually checked.
        validator.check_int(sum_square_y_shape[0], sum_square_x_shape[1], validator.EQ,
                            "the first dim of sum_square_y and the second dim of sum_square_x",
                            self.name)
        validator.check_int(sum_square_y_shape[0], 1, validator.EQ,
                            "the first dim of sum_square_y", self.name)

        k = y_shape[0]
        em_size = x_shape[1]
        # `(1,)` is a one-element shape tuple; the former `(1)` was just the int 1.
        return (k, em_size), (k, 1), (1,)
1800
+
1801
+
1802
class ClipByNorm(PrimitiveWithInfer):
    r"""
    Clips tensor values to a maximum :math:`L_2`-norm.

    Note:
        The output tensor of this operator remains the same with input tensor if the :math:`L_2`-norm of the input
        tensor is not greater than the argument `clip_norm`. Otherwise the output tensor will be normalized as:

        .. math::
            \text{output}(X) = \frac{\text{clip\_norm} * X}{L_2(X)},

        where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.

    Args:
        axis (Union[None, int, tuple(int), list(int)]): Compute the `L_2`-norm along the specific dimension.
            Default: ``None``, all dimensions to calculate.

    Inputs:
        - **x** (Tensor) - Tensor of shape N-D. The type must be float16 or float32.
        - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`.
          Or a Tensor which shape can be broadcast to the shape of `x`. The type must be float16 or float32.

    Outputs:
        Tensor, clipped Tensor with the same shape as the `x`, whose type is float32.

    Raises:
        TypeError: If `axis` is not one of None, int, tuple(int) and list(int).
        TypeError: If dtype of `x` is neither float16 nor float32.
        TypeError: If dtype of `clip_norm` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> clip_by_norm = inner.ClipByNorm()
        >>> x = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
        >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
        >>> output = clip_by_norm(x, clip_norm)
        >>> print(output.shape)
        (4, 16)
    """

    @prim_attr_register
    def __init__(self, axis=None):
        """Initialize ClipByNorm"""
        # None means "all dimensions" and is stored as an empty tuple.
        self.axis = () if axis is None else axis
        validator.check_value_type('axis', self.axis, [int, tuple, list], self.name)
        axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
        for i, value in enumerate(axis_check):
            validator.check_value_type('axis[%d]' % i, value, [int], self.name)
        self.init_attrs['axis'] = self.axis
        self.add_prim_attr('axis', self.axis)
        self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])

    def infer_shape(self, x_shape, clip_norm_shape):
        """Infer shape for ClipByNorm"""
        x_dim = len(x_shape)
        axis = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
        # Each axis must lie in [-x_dim, x_dim) (left bound inclusive, per INC_LEFT).
        for value in axis:
            validator.check_int_range(value, -x_dim, x_dim, validator.INC_LEFT, 'axis', self.name)
        return x_shape

    def infer_dtype(self, x_type, clip_norm_type):
        """Infer data type for ClipByNorm"""
        validator.check_tensor_dtype_valid("x_type", x_type, [mstype.float16, mstype.float32], self.name)
        validator.check_tensor_dtype_valid("clip_norm_type", clip_norm_type,
                                           [mstype.float16, mstype.float32], self.name)
        # The output is always float32, whatever the input float type.
        return mstype.float32
1874
+
1875
+
1876
class TopTypeof(Primitive):
    """
    Internal primitive method, used to speed up mindspore.ops.typeof.

    Gives back the top-level type of its argument. Arguments whose class name is in a prebuilt cache are
    answered directly; any class name containing 'Tensor' is looked up under the 'Tensor' key. Anything else
    falls back to constant folding through the PyNative executor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        self.prim = Primitive('TopTypeof')
        # Keyed by `type(x).__name__`.
        self.typeof_cache = dict(
            slice=mstype.Slice(),
            list=mstype.List(),
            tuple=mstype.Tuple(),
            Tensor=mstype.tensor_type,
            NoneType=mstype.NoneType(),
            int=mstype.Int(),
            bool=mstype.Bool(),
            ellipsis=mstype.Ellipsis_(),
            dict=mstype.Dict(),
        )

    def __call__(self, x):
        type_name = type(x).__name__
        cache_key = 'Tensor' if 'Tensor' in type_name else type_name
        if cache_key not in self.typeof_cache:
            return _pynative_executor.constant_folding(self.prim, x)
        return self.typeof_cache.get(cache_key)
1910
+
1911
+
1912
class MixedPrecisionCast(Primitive):
    r"""
    Internal primitive method, to achieve mindspore.functional.mixed_precision_cast.

    Note:
        This internal primitive method is used to do mixed precision conversion.
        Only Tensors whose dtype is float16, float32, float64 or bfloat16 are cast; every other leaf is
        returned unchanged.

    Inputs:
        - **dtype** (Union[Float16, Float32]) - The data type of the output object.
        - **input** (Union[Tensor, Tuple, Dictionary, KeywordArg]) - The object to be cast.

    Outputs:
        Object, its dtype is the same as `dtype` and shape is the same as 'input'.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore import dtype as mstype
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> x = Tensor(np.ones([2, 3], dtype=np.float32))
        >>> out = inner.MixedPrecisionCast()(mstype.float16, x)
        >>> print(out.dtype)
        Float16
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MixedPrecisionCast"""
        self.init_prim_io_names(inputs=['dst_dtype', 'input_x'], outputs=['output'])
        self.cast = Cast()
        self.hyper_map = C.HyperMap()

    def __call__(self, dst_dtype, x):
        def cast_inner(data):
            # Only floating-point Tensors are converted; everything else passes through unchanged.
            if isinstance(data, Tensor) and data.dtype in (mstype.float16, mstype.float32,
                                                           mstype.float64, mstype.bfloat16):
                return self.cast(data, dst_dtype)
            return data

        # HyperMap applies cast_inner to each element of the (possibly nested) input.
        return self.hyper_map(cast_inner, x)
1956
+
1957
+
1958
class CheckBprop(PrimitiveWithInfer):
    """
    Checks whether the data type and the shape of corresponding elements from tuples x and y are the same.

    Args:
        prim_to_check (str): The name of the primitive being checked. Default: ''.

    Inputs:
        - **input_x** (tuple[Tensor]) - The `input_x` contains the outputs of bprop to be checked.
        - **input_y** (tuple[Tensor]) - The `input_y` contains the inputs of bprop to check against.

    Outputs:
        Tuple[Tensor], the `input_x`,
        if data type and shape of corresponding elements from `input_x` and `input_y` are the same.

    Raises:
        TypeError: If `input_x` or `input_y` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.op = ops.CheckBprop()
        ...     def construct(self, x, y):
        ...         return self.op(x, y)
        ...
        >>> net = Net()
        >>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
        >>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
        >>> output = net(input_x, input_y)
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 2.00000000e+00,  2.00000000e+00],
         [ 2.00000000e+00,  2.00000000e+00]]),)
    """

    @prim_attr_register
    def __init__(self, prim_to_check=""):
        """Initialize CheckBprop"""
        self.prim_to_check = prim_to_check

    def infer_shape(self, xshapes, yshapes):
        """infer shape"""
        tips = "user defined method 'bprop'"
        validator.check_value_type('grads', xshapes, (tuple,), tips)
        validator.check_value_type('params', yshapes, (tuple,), tips)
        if len(xshapes) != len(yshapes):
            raise ValueError(f"For {tips} the number of return values(gradients) must be equal to "
                             f"the number of input arguments except 'out' and 'dout', "
                             f"which is:{len(yshapes)} but got {len(xshapes)}.")

        def shape_equal(shape1, shape2):
            # Equal rank is required; a -1 (dynamic) dimension on either side matches anything.
            if len(shape1) != len(shape2):
                return False
            for shape_axis1, shape_axis2 in zip(shape1, shape2):
                if shape_axis1 == -1 or shape_axis2 == -1:
                    continue
                if shape_axis1 != shape_axis2:
                    return False
            return True

        for i, (xshape, yshape) in enumerate(zip(xshapes, yshapes)):
            # Empty/falsy shapes (e.g. scalars) are not compared.
            if not xshape or not yshape:
                continue

            if not shape_equal(xshape, yshape):
                raise ValueError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
                                 f"should have the same shape as the {i}th argument, "
                                 f"which is:{yshape}, but got: {xshape}.")
        return xshapes

    def infer_dtype(self, xdtypes, ydtypes):
        """infer dtype"""
        tips = "user defined method 'bprop'"
        validator.check_value_type('grads', xdtypes, (tuple,), tips)
        validator.check_value_type('params', ydtypes, (tuple,), tips)
        if len(xdtypes) != len(ydtypes):
            raise ValueError(f"For {tips}, the number of return values(gradients) must be equal to "
                             f"the number of input arguments except 'out' and 'dout', "
                             f"which is:{len(ydtypes)} but got {len(xdtypes)}.")
        for i, (xdtype, ydtype) in enumerate(zip(xdtypes, ydtypes)):
            if isinstance(xdtype, mstype.AnythingType) or isinstance(ydtype, mstype.AnythingType):
                continue
            if isinstance(ydtype, mstype.FunctionType):
                if not isinstance(xdtype, mstype.EnvType):
                    raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) type "
                                    f"should be {mstype.EnvType}, but got {xdtype}.")
                # An EnvType gradient is the accepted counterpart of a function argument. It never compares
                # equal to the FunctionType, so skip the generic equality check below.
                continue
            if xdtype != ydtype:
                raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
                                f"should have the same dtype as the {i}th argument, "
                                f"which is:{ydtype}, but got: {xdtype}.")
        return xdtypes
2056
+
2057
+
2058
# Shared module-level CheckBprop instance, built with the default `prim_to_check` ("").
check_bprop = CheckBprop()
2059
+
2060
+
2061
class SameTypeShape(PrimitiveWithInfer):
    """
    Checks whether the data type and shape of two tensors are the same.

    Refer to :func:`mindspore.ops.same_type_shape` for more detail.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> input_y = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> output = ops.SameTypeShape()(input_x, input_y)
        >>> print(output)
        [[2. 2.]
         [2. 2.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SameTypeShape"""

    def __call__(self, x, y):
        """PyNative path: validate both tensors and return `x` unchanged."""
        for arg_name, arg in (('x', x), ('y', y)):
            validator.check_value_type(arg_name, arg, Tensor, self.name)
        validator.check('x dtype', x.dtype, 'y dtype', y.dtype, validator.EQ, self.name, TypeError)
        validator.check('x shape', x.shape, 'y shape', y.shape, validator.EQ, self.name)
        return x

    def __infer__(self, x, y):
        """Graph path: validate both abstracts and return the abstract of `x` unchanged."""
        x_dtype = x['dtype']
        y_dtype = y['dtype']
        for arg_name, arg_dtype in (('x', x_dtype), ('y', y_dtype)):
            validator.check_subclass(arg_name, arg_dtype, mstype.tensor_type, self.name)
        validator.check('x dtype', x_dtype, 'y dtype', y_dtype, validator.EQ, self.name, TypeError)
        validator.check('x shape', x['shape'], 'y shape', y['shape'], validator.EQ, self.name)
        return x


same_type_shape_ = SameTypeShape()
2100
+
2101
+
2102
def _is_subclass_(type_, dtype):
    """Report whether `type_` is a ``typing.Type`` that is a subclass of `dtype`.

    Inputs that are not ``typing.Type`` instances yield ``False`` rather than raising.
    Otherwise the result of ``typing.is_subclass`` is returned as-is.
    """
    return isinstance(type_, typing.Type) and typing.is_subclass(type_, dtype)
2106
+
2107
+
2108
class IsSubClass(PrimitiveWithInfer):
    """
    Checks whether this type is a sub-class of another type.

    Inputs:
        - **sub_type** (mindspore.dtype) - The type to be checked. Only constant value is allowed.
        - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.

    Outputs:
        bool, the check result.

    Raises:
        TypeError: If `sub_type` or `type_` is not a Type.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> output = ops.IsSubClass()(mindspore.int32, mindspore.intc)
        >>> print(output)
        True
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IsSubClass (no attributes)."""

    def __infer__(self, sub_type, type_):
        """Fold the subclass check into a constant value at inference time."""
        candidate = sub_type['value']
        target = type_['value']

        validator.check_value_type("sub_type", candidate, [mstype.Type], self.name)
        validator.check_value_type("type_", target, [mstype.Type], self.name)

        return {'shape': (),
                'dtype': mstype.type_type,
                'value': _is_subclass_(candidate, target)}


issubclass_ = IsSubClass()
2151
+
2152
+
2153
class IsInstance(PrimitiveWithInfer):
    """
    Checks whether an object is an instance of a target type.

    Inputs:
        - **inst** (Any Object) - The instance to be checked. Only constant value is allowed.
        - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.

    Outputs:
        bool, the check result.

    Raises:
        TypeError: If `type_` is not a Type.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> inst = 1
        >>> output = ops.IsInstance()(inst, mindspore.int32)
        >>> print(output)
        False
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IsInstance (no attributes)."""

    def __infer__(self, inst, type_):
        """Fold the instance check into a constant value at inference time."""
        inst_dtype = inst['dtype']
        target = type_['value']

        validator.check_value_type("type_", target, [mstype.Type], self.name)

        # NOTE(review): the list/tuple branches assume inst['dtype'] is a Python list/tuple
        # for list/tuple inputs -- confirm against the abstract-to-Python conversion.
        if target == mstype.list_:
            matched = isinstance(inst_dtype, list)
        elif target == mstype.tuple_:
            matched = isinstance(inst_dtype, tuple)
        else:
            matched = _is_subclass_(inst_dtype, target)

        return {'shape': (),
                'dtype': mstype.type_type,
                'value': matched}
2198
+
2199
+
2200
class ConvertToAdapterTensor(Primitive):
    """
    Convert a tensor from MindSpore's Tensor type to MSAdapter's Tensor type,
    where MSAdapter's Tensor is a subclass of MindSpore's Tensor.

    Inputs:
        - **x** (Tensor) - The input tensor.

    Outputs:
        A tensor, whose type is MSAdapter's Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor([1, 2, 3])
        >>> x = ops.ConvertToAdapterTensor()(x)
        >>> print(x)
        [1 2 3]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ConvertToAdapterTensor (no attributes)."""

    def __call__(self, x):
        """PyNative path: wrap `x` with ``ms_adapter_registry.tensor`` (``cast_tensor=True``)."""
        adapter_tensor_cls = ms_adapter_registry.tensor
        return adapter_tensor_cls(x, cast_tensor=True)


convert_to_adapter_tensor = ConvertToAdapterTensor()
2231
+
2232
+
2233
class ConvertToMsTensor(Primitive):
    """
    Convert a tensor from MSAdapter's Tensor type to MindSpore's Tensor type,
    where MSAdapter's Tensor is a subclass of MindSpore's Tensor.

    Inputs:
        - **x** (Tensor) - The input tensor.

    Outputs:
        A tensor, whose type is MindSpore's Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor([1, 2, 3])
        >>> x = ops.ConvertToMsTensor()(x)
        >>> print(x)
        [1 2 3]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ConvertToMsTensor (no attributes)."""

    def __call__(self, x):
        """PyNative path: rebuild a StubTensor from its parts, otherwise deep-copy `x`."""
        if not isinstance(x, StubTensor):
            return ops.auto_generate.deepcopy(x)
        return StubTensor(stub=x.stub, tensor=x.tensor)


convert_to_ms_tensor = ConvertToMsTensor()
2266
+
2267
+
2268
class GetGrad(Primitive):
    """
    Use a position id or a Parameter object to get the gradient from the output
    returned by :func:`mindspore.ops.grad`.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize GetGrad"""
        self.init_prim_io_names(
            inputs=['gradients', 'x'], outputs=['gradient'])

    def __call__(self, gradients, x):
        """Return the gradient tagged with `x` inside the nested tuple `gradients`.

        Args:
            gradients (tuple): Nested tuples in which a tagged gradient appears as a
                2-tuple ``(identifier, gradient)``.
            x (Union[int, Parameter]): A position id, or a Parameter whose ``name`` is used
                as the identifier.

        Returns:
            The gradient paired with the identifier. If several pairs match, the last one
            found in depth-first order is returned.

        Raises:
            TypeError: If `x` is neither an int nor a Parameter.
            RuntimeError: If no pair carries the identifier.
        """
        if not isinstance(x, (int, Parameter)):
            raise TypeError(
                f"For `get_grad`, the `x` should be an integer or a Parameter, but got {x}")
        hash_id = x.name if isinstance(x, Parameter) else x
        output = None

        def _get_grad(grads, identifier):
            nonlocal output
            if not isinstance(grads, tuple):
                return
            # Only plain ids (int positions / str parameter names) can be tags. Comparing
            # `identifier` with a Tensor would be elementwise, and its truth value is ambiguous
            # (or wrongly true for a scalar tensor).
            is_tag = (len(grads) == 2 and isinstance(grads[0], (int, str))
                      and grads[0] == identifier)
            if is_tag:
                output = grads[1]
                return
            for gradient in grads:
                _get_grad(gradient, identifier)

        _get_grad(gradients, hash_id)
        if output is None:
            raise RuntimeError(
                f"Can not find the gradient for position or Parameter {x}")
        return output
2304
+
2305
+
2306
class IsParameter(PrimitiveWithInfer):
    """
    Check whether the input is a `Parameter`.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IsParameter"""

    def __call__(self, x):
        """PyNative path: report whether `x` is a Parameter instance."""
        is_param = isinstance(x, Parameter)
        return is_param

    def __infer__(self, x):
        """Graph path: the input counts as a Parameter when its abstract dtype is a ``RefType``."""
        is_ref = isinstance(x['dtype'], mstype.RefType)
        return {'shape': [], 'dtype': mstype.bool_, 'value': is_ref}
2322
+
2323
+
2324
class TileSize(Primitive):
    r"""
    Tile size for matmul
    """

    @prim_attr_register
    def __init__(self):
        """Initialize TileSize"""
        self.init_prim_io_names(inputs=['shape', 'out_shape', 'ndim'], outputs=['output'])

    def __call__(self, shape, out_shape, ndim):
        """Return an `ndim`-length tuple of 1s.

        At each axis where `shape` and `out_shape` differ, the entry is replaced by the
        `out_shape` dimension.
        """
        multiples = [1] * ndim
        for axis, (src_dim, dst_dim) in enumerate(zip(shape, out_shape)):
            if src_dim != dst_dim:
                multiples[axis] = dst_dim
        return tuple(multiples)
2340
+
2341
+
2342
class GetitemTensorIndexInfo(Primitive):
    r"""
    Get the tensor index info used for getitem.
    """

    @prim_attr_register
    def __init__(self, is_ascend):
        """Initialize GetitemTensorIndexInfo.

        Args:
            is_ascend (bool): Flag forwarded to ``Tensor_.getitem_index_info`` on every call.
        """
        output_names = ["new_index", "tensor_update_types", "tensor_update_args"]
        self.init_prim_io_names(inputs=['data', 'index'], outputs=output_names)
        validator.check_value_type('is_ascend', is_ascend, [bool], self.name)
        self.is_ascend = is_ascend

    def __call__(self, data, index):
        """Delegate to ``Tensor_.getitem_index_info`` with the stored `is_ascend` flag."""
        flag = self.is_ascend
        return Tensor_.getitem_index_info(data, index, flag)
2357
+
2358
+
2359
class SetitemTensorIndexInfo(Primitive):
    r"""
    Get the tensor index info used for setitem.
    """

    @prim_attr_register
    def __init__(self, is_ascend):
        """Initialize SetitemTensorIndexInfo.

        Args:
            is_ascend (bool): Flag forwarded to ``Tensor_.setitem_index_info`` on every call.
        """
        output_names = ['new_index',
                        'v_transfer_types',
                        'v_transfer_args',
                        'tensor_update_types',
                        'tensor_update_args']
        self.init_prim_io_names(inputs=['data', 'index', 'value'], outputs=output_names)
        validator.check_value_type('is_ascend', is_ascend, [bool], self.name)
        self.is_ascend = is_ascend

    def __call__(self, data, index, value):
        """Delegate to ``Tensor_.setitem_index_info`` with the stored `is_ascend` flag."""
        flag = self.is_ascend
        return Tensor_.setitem_index_info(data, index, value, flag)
2378
+
2379
+
2380
class IsConstant(Primitive):
    r"""
    Check whether the input is constant.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IsConstant"""

    def __call__(self, x):
        """PyNative path: the input is not inspected; the result is always ``True``."""
        is_constant = True
        return is_constant
2391
+
2392
+
2393
class SelectView(Primitive):
    r"""
    Select tensor of view
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SelectView by declaring its input and output names."""
        io_inputs = ['input_tensor', 'input_indices', 'axis']
        self.init_prim_io_names(inputs=io_inputs, outputs=['output'])
2401
+
2402
+
2403
class CopyWithSlice(Primitive):
    r"""
    Copy data to discontinuous tensor
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CopyWithSlice."""
        # Flag the primitive as having a memory side effect.
        self.add_prim_attr('side_effect_mem', True)
        # The output reuses the name 'x'; presumably the result aliases the first input -- confirm.
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])
2412
+
2413
+
2414
class FFN(Primitive):
    r"""
    The FFN computation is similar to a Feed-Forward Network: it contains matmul + gelu + matmul.

    Args:
        activation (str): The activation type, set to 'fastgelu' or 'gelu'.
            Only 'fastgelu' is supported for now. This argument is required.
        inner_precise (int): The precise mode, set to 0 for high precision or 1 for high performance.
            Only 1 is supported for now. This argument is required.

    Inputs:
        - **x** (Tensor) - The input tensor with data type of int8, float16.
          Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
        - **weight1** (Tensor) - The weight1 tensor with data type of float16.
          Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
        - **weight2** (Tensor) - The weight2 tensor with data type of float16.
          Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
        - **expert_tokens** (Tensor) - The expert tokens tensor with data type of int64.
          Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, .., 9)`
          indicates that the 0th expert deals with 2 tokens, the 1st expert deals with 1 token,
          the 2nd expert deals with no tokens, and so on.
        - **bias1** (Tensor) - The bias1 tensor with data type of float16.
          Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
        - **bias2** (Tensor) - The bias2 tensor with data type of float16.
          Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
        - **scale** (Tensor) - The scale tensor with data type of float16. Not enabled yet.
        - **offset** (Tensor) - The offset tensor with data type of float16. Not enabled yet.
        - **deq_scale1** (Tensor) - The deq_scale1 tensor with data type of float16. Not enabled yet.
        - **deq_scale2** (Tensor) - The deq_scale2 tensor with data type of float16. Not enabled yet.

        The primitive additionally declares the inputs ``antiquant_scale1``, ``antiquant_scale2``,
        ``antiquant_offset1`` and ``antiquant_offset2``, which are not documented here.

    Outputs:
        Tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`. With data type of float16.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore.ops.operations import _inner_ops
        >>> b = 4
        >>> s = 128
        >>> h = 1024
        >>> h_f = 4 * h
        >>> e = 16
        >>> x = Tensor(np.random.randn(s, h).astype(np.float16))
        >>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
        >>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
        >>> expert_tokens = Tensor(np.full(e, 8))
        >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
        >>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
        >>> ffn = _inner_ops.FFN("fastgelu", 1)
        >>> output = ffn(x, w1, w2, expert_tokens, bias1, bias2)
        >>> print(output)
    """

    @prim_attr_register
    def __init__(self, activation, inner_precise):
        """Initialize FFN: declare I/O names, then type-check the two attributes."""
        input_names = ["x", "weight1", "weight2", "expert_tokens", "bias1",
                       "bias2", "scale", "offset", "deq_scale1", "deq_scale2",
                       "antiquant_scale1", "antiquant_scale2",
                       "antiquant_offset1", "antiquant_offset2"]
        self.init_prim_io_names(inputs=input_names, outputs=["y"])
        validator.check_value_type("activation", activation, [str], self.name)
        validator.check_value_type("inner_precise", inner_precise, [int], self.name)
2479
+
2480
+
2481
class _MirrorSilentCheck(PrimitiveWithInfer):
    """
    The operator _MirrorSilentCheck implements accuracy-sensitive detection on the tensor input in backpropagator.
    Call _MirrorSilentCheck in method __call__ of derived class to implement accuracy-sensitive detection.

    Args:
        min_steps (int): Stored on the instance as ``min_steps``. Default: ``8``.

    The thresholds ``thresh_l1``/``thresh_l2`` and ``coeff_l1``/``coeff_l2`` are read from the
    environment variables ``NPU_ASD_UPPER_THRESH`` and ``NPU_ASD_SIGMA_THRESH`` (see `parse_thresh`).

    Inputs:
        - **input** (Tensor) : The tensor used for detection.
          Its data type must be mindspore.float16, mindspore.float32 or mindspore.bfloat16.
        - **pre_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
          Please only generated by method generate_params() of ASDBase.
        - **min_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
          Please only generated by method generate_params() of ASDBase.
        - **max_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
          Please only generated by method generate_params() of ASDBase.
        - **cnt** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
          Please only generated by method generate_params() of ASDBase.
          After each invocation of _MirrorSilentCheck, increment the value of cnt by one.
        - **loss_scale** : Sixth input accepted by shape/dtype inference; its role is not
          shown here -- TODO document.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.
    """
    @prim_attr_register
    def __init__(self, min_steps=8):
        upper_thresh, sigma_thresh = self.get_thresh()
        self.min_steps = min_steps
        # Tuple assignment runs left to right, so attributes are set in the order
        # thresh_l1, coeff_l1, thresh_l2, coeff_l2.
        self.thresh_l1, self.coeff_l1 = upper_thresh[0], sigma_thresh[0]
        self.thresh_l2, self.coeff_l2 = upper_thresh[1], sigma_thresh[1]
        self.add_prim_attr('side_effect_mem', True)

    def parse_thresh(self, env_var_name, default_value, min_value):
        """Parse a "first,second" pair of non-negative integers from an environment variable.

        Falls back to `default_value` when the variable is malformed (not exactly two
        digit-only fields). Each value is clamped below at `min_value`. If the first value is
        not strictly greater than the second, the unclamped `default_value` is used instead.

        Returns:
            list[float], the two thresholds.
        """
        raw = os.environ.get(env_var_name, default=default_value)
        parts = [item.strip() for item in raw.split(",")]
        well_formed = len(parts) == 2 and all(item.isdigit() for item in parts)
        if not well_formed:
            parts = default_value.split(",")
        thresh = [float(max(int(item), min_value)) for item in parts]
        if thresh[0] <= thresh[1]:
            thresh = [float(item) for item in default_value.split(",")]

        return thresh

    def get_thresh(self):
        """Return the (upper, sigma) threshold pairs read from the environment."""
        upper = self.parse_thresh("NPU_ASD_UPPER_THRESH", "1000000,10000", 3)
        sigma = self.parse_thresh("NPU_ASD_SIGMA_THRESH", "100000,5000", 3)
        return upper, sigma

    def infer_shape(self, x_shape, pre_shape, min_shape, max_shape, n_step, loss_scale_shape):
        """The output has the shape of the first input."""
        return x_shape

    def infer_dtype(self, x_dtype, pre_dtype, min_dtype, max_dtype, n_dtype, loss_scale_dtype):
        """The output has the dtype of the first input."""
        return x_dtype
2533
+
2534
+
2535
class _VirtualConverterEnd(PrimitiveWithInfer):
    """
    Auto parallel virtual operator.

    Infers an output whose leading dimension is the first input's leading dimension
    multiplied by `input_nums`. The remaining dimensions and the dtype come from the
    first input.
    """

    @prim_attr_register
    def __init__(self, input_nums):
        """Initialize _VirtualConverterEnd."""
        self.input_nums = input_nums

    def infer_shape(self, *args):
        """Scale the first input's leading dimension by `input_nums`.

        A negative leading dimension (an unknown size such as -1) is kept as-is. Multiplying
        it would yield a meaningless value such as -2.
        """
        first_shape = args[0]
        lead = first_shape[0]
        if lead >= 0:
            lead *= self.input_nums
        return (lead,) + tuple(first_shape[1:])

    def infer_dtype(self, *args):
        """Return the dtype of the first input."""
        return args[0]
2550
+
2551
+
2552
class _VirtualConverterBegin(PrimitiveWithInfer):
    """
    Auto parallel virtual operator.

    Infers `output_nums` identical outputs, each with the input's leading dimension divided
    by `output_nums` and the other dimensions and dtype unchanged.
    """

    @prim_attr_register
    def __init__(self, output_nums):
        """Initialize _VirtualConverterBegin."""
        self.output_nums = output_nums

    def infer_shape(self, arg):
        """Split the leading dimension of `arg` across `output_nums` outputs.

        Returns:
            tuple, `output_nums` copies of the per-output shape.
        """
        lead = arg[0]
        # Floor division keeps the dimension an integer (true division produced a float).
        # Negative, i.e. unknown, dimensions are passed through unchanged.
        if lead >= 0:
            lead //= self.output_nums
        new_arg = (lead,) + tuple(arg[1:])
        return (new_arg,) * self.output_nums

    def infer_dtype(self, arg):
        """Return the input dtype once per output."""
        return (arg,) * self.output_nums