mindspore 2.4.0__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the registry's advisory page for more details.

Files changed (1387)
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1318 @@
1
+ # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
2
+ #
3
+ # Copyright 2020-2024 Huawei Technologies Co., Ltd
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ============================================================================
17
+
18
+ """Basic composite operations."""
19
+ from __future__ import absolute_import
20
+ from functools import partial
21
+
22
+ from types import FunctionType, MethodType
23
+ import numpy as np
24
+ import mindspore as ms
25
+ from mindspore import context
26
+ from mindspore.common.parameter import Parameter, ParameterTuple
27
+ from mindspore.parallel._utils import _grads_divided_by_device_num_if_recomputation
28
+ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
29
+ TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
30
+ SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
31
+ ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
32
+ ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_, DictSetItem_, \
33
+ HandleBoolTensor_, PreSetitemByTuple_, StarredGetItem_, \
34
+ StarredUnpack_, StarredUnpackMerge_, IterConverter_, HasNext_, Next_, MSContext
35
+ from mindspore.common import dtype as mstype
36
+ from mindspore.common.api import jit, _pynative_executor, _wrap_func
37
+ from mindspore.common.api import _add_flags, _core
38
+ from mindspore.ops.primitive import Primitive
39
+ from mindspore.ops import signature as sig
40
+
41
# `__all__` must contain names (strings), not the objects themselves: with class objects in it,
# `from mindspore.ops.composite.base import *` raises "TypeError: Item in ... __all__ must be str".
__all__ = ['TupleAdd_', 'ListAdd_', 'UnpackCall_', 'TupleGetItemTensor_', 'SequenceSliceGetItem_',
           'ListSliceSetItem_', 'ZerosLike_', 'TensorIndexGetitem_', 'TensorIndexSetitem_',
           'HandleBoolTensor_', 'PreSetitemByTuple_']
44
+
45
+
46
def add_flags(fn=None, **flags):
    """
    A decorator that adds flags to the function.

    Note:
        Only supports bool value.

    Args:
        fn (Function): Function or cell to add flags to. Default: ``None`` .
        flags (dict): Flags given as keyword arguments, e.g. ``predict=True``. Default: no flags.

    Returns:
        Function, the function with added flags.

    Examples:
        >>> net = Net()
        >>> net = add_flags(net, predict=True)
        >>> print(hasattr(net, '_func_graph_flags'))
        True
    """
    # Thin public wrapper: the flag handling itself lives in `_add_flags` (mindspore.common.api).
    return _add_flags(fn, **flags)
68
+
69
+
70
def core(fn=None, **flags):
    """
    A decorator that adds a flag to the function.

    By default the function is marked as a core function (flag value ``True``), so this decorator
    can be used to set that flag on a graph.

    Args:
        fn (Function, optional): Function to add the flag to. Default: ``None`` .
        flags (dict, optional): Flags given as keyword arguments, such as ``core``, which indicates that
            this is a core function, or other flags. Default: no flags.

    Returns:
        Function, the function with added flags.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> net = core(net, predict=True)
        >>> print(hasattr(net, '_func_graph_flags'))
        True
    """
    # Thin public wrapper: the flag handling itself lives in `_core` (mindspore.common.api).
    return _core(fn, **flags)
93
+
94
+
95
def _get_grad_weights_id(weights=None):
    """
    Generate an id string for `weights`, used to decide whether a cached gradient function can be reused.

    Different `weights` arguments must not produce the same id, otherwise a stale cached gradient
    function would be returned. Each parameter is therefore encoded with ``repr`` (which quotes and
    escapes the name) rather than by plain concatenation, where a parameter name such as ``"aTruex"``
    could make two different weight lists concatenate to the same string. A bare `Parameter` is tagged
    differently from a sequence so that ``p`` and ``[p]`` get distinct ids.

    Args:
        weights (Union[Parameter, ParameterTuple, list, tuple], optional): The weights given to a grad call.
            Items of a plain list or tuple that are not `Parameter` are ignored. Default: ``None`` .

    Returns:
        str, ``""`` if `weights` is not a Parameter, ParameterTuple, list or tuple; otherwise an id built
        from the kind of container and each parameter's ``name`` and ``requires_grad``.
    """
    if isinstance(weights, Parameter):
        return "param:" + repr((weights.name, weights.requires_grad))
    # ParameterTuple subclasses tuple, so it must be tested before the generic sequence branch.
    if isinstance(weights, ParameterTuple):
        params = tuple(weights)
    elif isinstance(weights, (list, tuple)):
        # Plain tuples used to fall through to "", so any two tuples of weights shared one cache entry.
        params = tuple(item for item in weights if isinstance(item, Parameter))
    else:
        return ""
    return "seq:" + repr(tuple((param.name, param.requires_grad) for param in params))
105
+
106
+
107
+ class GradOperation(GradOperation_):
108
+ """
109
+ A higher-order function which is used to generate the gradient function for the input function.
110
+
111
+ The gradient function generated by `GradOperation` higher-order function can be customized by
112
+ construction arguments.
113
+
114
+ For example, given an input function `net = Net()` that takes `x` and `y` as inputs, and has a parameter `z`,
115
+ see `Net` in Examples.
116
+
117
+ - Used to get the derivative of the input:
118
+
119
+ 1. Returns gradients with respect to the first input (see `GradNetWrtX` in Examples).
120
+
121
+ 1) Construct a `GradOperation` higher-order function with default arguments: `grad_op = GradOperation()`.
122
+
123
+ 2) Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`.
124
+
125
+ 3) Call the gradient function with input function's inputs to get the gradients with respect to the first
126
+ input: `grad_op(net)(x, y)`.
127
+
128
+ 2. Returns gradients with respect to all inputs (see `GradNetWrtXY` in Examples).
129
+
130
+ 1) Construct a `GradOperation` higher-order function with `get_all=True` which indicates getting gradients
131
+ with respect to all inputs, they are `x` and `y` in example function `Net()`:
132
+ `grad_op = GradOperation(get_all=True)`.
133
+
134
+ 2) Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`.
135
+
136
+ 3) Call the gradient function with input function's inputs to get the gradients with respect to all inputs:
137
+ `gradient_function(x, y)`.
138
+
139
+ - Used to get the derivative of the parameters:
140
+
141
+ Returns gradients with respect to given parameters (see `GradNetWithWrtParams` in Examples).
142
+
143
+ 1. Construct a `GradOperation` higher-order function with `get_by_list=True`:
144
+ `grad_op = GradOperation(get_by_list=True)`.
145
+
146
+ 2. Construct a `ParameterTuple` that will be passed to the input function when constructing
147
+ `GradOperation` higher-order function, it will be used as a parameter filter that determine
148
+ which gradient to return: `params = ParameterTuple(net.trainable_params())`.
149
+
150
+ 3. Call it with input function and `params` as arguments to get the gradient function:
151
+ `gradient_function = grad_op(net, params)`.
152
+
153
+ 4. Call the gradient function with input function's inputs to get the gradients with
154
+ respect to given parameters: `gradient_function(x, y)`.
155
+
156
+ - Used to get the derivative of the inputs and parameters at the same time:
157
+ Returns gradients with respect to all inputs and given parameters in the format of ((dx, dy), (dz))
158
+ (see `GradNetWrtInputsAndParams` in Examples).
159
+
160
+ 1. Construct a `GradOperation` higher-order function with `get_all=True` and `get_by_list=True`:
161
+ `grad_op = GradOperation(get_all=True, get_by_list=True)`.
162
+
163
+ 2. Construct a `ParameterTuple` that will be passed along input function when constructing
164
+ `GradOperation` higher-order function: `params = ParameterTuple(net.trainable_params())`.
165
+
166
+ 3. Call it with input function and `params` as arguments to get the gradient function:
167
+ `gradient_function = grad_op(net, params)`.
168
+
169
+ 4. Call the gradient function with input function's inputs to get the gradients with respect to
170
+ all inputs and given parameters: `gradient_function(x, y)`.
171
+
172
+ - We can configure the sensitivity(gradient with respect to output) by setting `sens_param` as True and
173
+ passing an extra sensitivity input to the gradient function, the sensitivity input should has the
174
+ same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples).
175
+
176
+ 1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`:
177
+ `grad_op = GradOperation(get_all=True, sens_param=True)`.
178
+
179
+ 2. Define `grad_wrt_output` as `sens_param` which works as the gradient with respect to output:
180
+ `grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.
181
+
182
+ 3. Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`.
183
+
184
+ 4. Call the gradient function with input function's inputs and `sens_param` to
185
+ get the gradients with respect to all inputs: `gradient_function(x, y, grad_wrt_output)`.
186
+
187
+ Note:
188
+ For above gradient functions, the returned gradient result may vary for grad result element number:
189
+
190
+ - Return a single value if only one result.
191
+ - Return a tuple for multiple results.
192
+ - Return an empty tuple for no result.
193
+
194
+ Args:
195
+ get_all (bool): If ``True`` , get all the gradients with respect to inputs. Default: ``False`` .
196
+ get_by_list (bool): If ``True`` , get all the gradients with respect to Parameter free variables.
197
+ If get_all and get_by_list are both ``False`` , get the gradient with respect to first input.
198
+ If get_all and get_by_list are both ``True`` , get the gradients with respect to inputs and
199
+ Parameter free variables at the same time in the form of ("gradients with respect to inputs",
200
+ "gradients with respect to parameter free variables"). Default: ``False`` .
201
+ sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
202
+ If sens_param is ``False`` , a 'ones_like(outputs)' sensitivity will be attached automatically.
203
+ Default: ``False`` .
204
+ If the sensor_param is ``True`` , a sensitivity (gradient with respect to output) needs to be transferred
205
+ through the positional parameter or key-value pair parameter. If the value is transferred through
206
+ the key-value pair parameter, the key must be sens.
207
+
208
+ Returns:
209
+ The higher-order function which takes a function as argument and returns gradient function for it.
210
+
211
+ Raises:
212
+ TypeError: If `get_all`, `get_by_list` or `sens_param` is not a bool.
213
+
214
+ Supported Platforms:
215
+ ``Ascend`` ``GPU`` ``CPU``
216
+
217
+ Examples:
218
+ >>> import mindspore
219
+ >>> import numpy as np
220
+ >>> from mindspore import dtype as mstype
221
+ >>> from mindspore import Tensor, ops, nn, Parameter
222
+ >>> class Net(nn.Cell):
223
+ ... def __init__(self):
224
+ ... super(Net, self).__init__()
225
+ ... self.matmul = ops.MatMul()
226
+ ... self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
227
+ ... def construct(self, x, y):
228
+ ... x = x * self.z
229
+ ... out = self.matmul(x, y)
230
+ ... return out
231
+ ...
232
+ >>> class GradNetWrtX(nn.Cell):
233
+ ... def __init__(self, net):
234
+ ... super(GradNetWrtX, self).__init__()
235
+ ... self.net = net
236
+ ... self.grad_op = ops.GradOperation()
237
+ ... def construct(self, x, y):
238
+ ... gradient_function = self.grad_op(self.net)
239
+ ... return gradient_function(x, y)
240
+ ...
241
+ >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
242
+ >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
243
+ >>> output = GradNetWrtX(Net())(x, y)
244
+ >>> print(output)
245
+ [[1.4100001 1.5999999 6.6 ]
246
+ [1.4100001 1.5999999 6.6 ]]
247
+ >>>
248
+ >>> class GradNetWrtXY(nn.Cell):
249
+ ... def __init__(self, net):
250
+ ... super(GradNetWrtXY, self).__init__()
251
+ ... self.net = net
252
+ ... self.grad_op = ops.GradOperation(get_all=True)
253
+ ... def construct(self, x, y):
254
+ ... gradient_function = self.grad_op(self.net)
255
+ ... return gradient_function(x, y)
256
+ >>>
257
+ >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
258
+ >>> y = Tensor([[0.1, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
259
+ >>> output = GradNetWrtXY(Net())(x, y)
260
+ >>> print(output)
261
+ (Tensor(shape=[2, 3], dtype=Float32, value=
262
+ [[ 4.50000000e+00, 2.70000005e+00, 3.60000014e+00],
263
+ [ 4.50000000e+00, 2.70000005e+00, 3.60000014e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
264
+ [[ 2.59999990e+00, 2.59999990e+00, 2.59999990e+00],
265
+ [ 1.89999998e+00, 1.89999998e+00, 1.89999998e+00],
266
+ [ 1.30000007e+00, 1.30000007e+00, 1.30000007e+00]]))
267
+ >>>
268
+ >>> class GradNetWrtXYWithSensParam(nn.Cell):
269
+ ... def __init__(self, net):
270
+ ... super(GradNetWrtXYWithSensParam, self).__init__()
271
+ ... self.net = net
272
+ ... self.grad_op = ops.GradOperation(get_all=True, sens_param=True)
273
+ ... self.grad_wrt_output = Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=mstype.float32)
274
+ ... def construct(self, x, y):
275
+ ... gradient_function = self.grad_op(self.net)
276
+ ... return gradient_function(x, y, self.grad_wrt_output)
277
+ >>>
278
+ >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
279
+ >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
280
+ >>> output = GradNetWrtXYWithSensParam(Net())(x, y)
281
+ >>> print(output)
282
+ (Tensor(shape=[2, 3], dtype=Float32, value=
283
+ [[ 2.21099997e+00, 5.09999990e-01, 1.49000001e+00],
284
+ [ 5.58800030e+00, 2.68000007e+00, 4.07000017e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
285
+ [[ 1.51999998e+00, 2.81999993e+00, 2.14000010e+00],
286
+ [ 1.09999990e+00, 2.04999995e+00, 1.54999995e+00],
287
+ [ 9.00000036e-01, 1.54999995e+00, 1.25000000e+00]]))
288
+ >>>
289
+ >>> class GradNetWithWrtParams(nn.Cell):
290
+ ... def __init__(self, net):
291
+ ... super(GradNetWithWrtParams, self).__init__()
292
+ ... self.net = net
293
+ ... self.params = ParameterTuple(net.trainable_params())
294
+ ... self.grad_op = ops.GradOperation(get_by_list=True)
295
+ ... def construct(self, x, y):
296
+ ... gradient_function = self.grad_op(self.net, self.params)
297
+ ... return gradient_function(x, y)
298
+ >>>
299
+ >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
300
+ >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
301
+ >>> output = GradNetWithWrtParams(Net())(x, y)
302
+ >>> print(output)
303
+ (Tensor(shape=[1], dtype=Float32, value= [ 2.15359993e+01]),)
304
+ >>>
305
+ >>> class GradNetWrtInputsAndParams(nn.Cell):
306
+ ... def __init__(self, net):
307
+ ... super(GradNetWrtInputsAndParams, self).__init__()
308
+ ... self.net = net
309
+ ... self.params = ParameterTuple(net.trainable_params())
310
+ ... self.grad_op = ops.GradOperation(get_all=True, get_by_list=True)
311
+ ... def construct(self, x, y):
312
+ ... gradient_function = self.grad_op(self.net, self.params)
313
+ ... return gradient_function(x, y)
314
+ >>>
315
+ >>> x = Tensor([[0.1, 0.6, 1.2], [0.5, 1.3, 0.1]], dtype=mstype.float32)
316
+ >>> y = Tensor([[0.12, 2.3, 1.1], [1.3, 0.2, 2.4], [0.1, 2.2, 0.3]], dtype=mstype.float32)
317
+ >>> output = GradNetWrtInputsAndParams(Net())(x, y)
318
+ >>> print(output)
319
+ ((Tensor(shape=[2, 3], dtype=Float32, value=
320
+ [[ 3.51999998e+00, 3.90000010e+00, 2.59999990e+00],
321
+ [ 3.51999998e+00, 3.90000010e+00, 2.59999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
322
+ [[ 6.00000024e-01, 6.00000024e-01, 6.00000024e-01],
323
+ [ 1.89999998e+00, 1.89999998e+00, 1.89999998e+00],
324
+ [ 1.30000007e+00, 1.30000007e+00, 1.30000007e+00]])), (Tensor(shape=[1], dtype=Float32, value=
325
+ [ 1.29020004e+01]),))
326
+ """
327
+
328
def __init__(self, get_all=False, get_by_list=False, sens_param=False):
    """Validate the gradient flags, initialize the base operator and reset the cache.

    Args:
        get_all (bool): Forwarded to the base operator as its `get_all` flag.
        get_by_list (bool): Forwarded to the base operator as its `get_by_list` flag.
        sens_param (bool): Forwarded to the base operator as its `sens_param` flag.

    Raises:
        TypeError: If any of `get_all`, `get_by_list` or `sens_param` is not a bool.
    """
    flags = (("get_all", get_all), ("get_by_list", get_by_list), ("sens_param", sens_param))
    for flag_name, flag_value in flags:
        if not isinstance(flag_value, bool):
            raise TypeError(f"For 'GradOperation', the '{flag_name}' should be bool, "
                            f"but got {type(flag_value).__name__}")
    self.get_all, self.get_by_list, self.sens_param = get_all, get_by_list, sens_param
    # The five trailing flags (get_by_position, has_aux, get_value, return_ids, merge_forward)
    # are always disabled for the public GradOperation.
    GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False, False, False, False, False)
    # Set to True on the inner instance that __call__ creates in PyNative mode.
    self.pynative_ = False
    self.grad_position = (0,)
    # Cache of the last gradient function built by __call__ and the key it was built for.
    self.grad_fn = None
    self.fn = None
    self.weights_id = None
347
+
348
def __call__(self, fn, weights=None):
    """Build, or return the cached, gradient function for `fn`.

    In GRAPH_MODE the result is a `jit`-decorated wrapper. On the inner instance
    (`pynative_` is True) the wrapper runs the forward pass and asks the PyNative executor
    for the gradients. Otherwise (outer call in PyNative mode) a new inner GradOperation is
    flagged `pynative_` and the wrapper simply defers to it.

    Args:
        fn: The function or Cell to differentiate.
        weights: Parameters to differentiate with respect to, or None.

    Returns:
        Callable: The gradient function. It is cached on this instance and reused while
        `fn` and the weights id (computed in GRAPH_MODE only) are unchanged.

    Raises:
        RuntimeError: Outside GRAPH_MODE, when `_pynative_executor.enable_grad()` is False.
    """
    graph_mode = context.get_context("mode") == context.GRAPH_MODE
    weights_id = _get_grad_weights_id(weights) if graph_mode else ''
    if self.grad_fn is not None and self.fn == fn and self.weights_id == weights_id:
        return self.grad_fn

    grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
    if graph_mode:
        # GRAPH_MODE (or Grad called from inside a `jit` function): compile the whole grad.
        is_cell = isinstance(fn, ms.nn.Cell)
        dynamic_shape_inputs = fn.get_inputs() if is_cell else None
        if is_cell:
            fn.grad_ops_label = True
        if self.get_by_list:
            @jit(input_signature=dynamic_shape_inputs)
            def after_grad(*args, **kwargs):
                return grad_(fn, weights)(*args, **kwargs)
        else:
            @jit(input_signature=dynamic_shape_inputs)
            def after_grad(*args, **kwargs):
                return grad_(fn)(*args, **kwargs)
    elif not self.pynative_:
        # Outer call in pure PyNative mode: only flags the inner operation and defers to it.
        MSContext.get_instance()._set_not_convert_jit(True)
        grad_.pynative_ = True
        if not _pynative_executor.enable_grad():
            raise RuntimeError("In no_grad context, you can not calculate gradient")
        # This wrapper must not use @jit; it calls grad_ directly.
        if self.get_by_list:
            def after_grad(*args, **kwargs):
                return grad_(fn, weights)(*args, **kwargs)
        else:
            def after_grad(*args, **kwargs):
                return grad_(fn)(*args, **kwargs)
    else:
        # Inner PyNative operation: run forward, then let the executor compute gradients.
        if not _pynative_executor.enable_grad():
            raise RuntimeError("In no_grad context, you can not calculate gradient")

        @_wrap_func
        def after_grad(*args, **kwargs):
            run_args = self._pynative_forward_run(fn, grad_, weights, *args, **kwargs)
            grads = _pynative_executor.grad(fn, grad_, weights, self.grad_position, *run_args)
            return _grads_divided_by_device_num_if_recomputation(grads)

    self.grad_fn, self.fn, self.weights_id = after_grad, fn, weights_id
    return self.grad_fn
400
+
401
def _pynative_forward_run(self, fn, grad, weights, *args, **kwargs):
    """PyNative forward run to build the grad graph.

    Args:
        fn: Function, method or Cell being differentiated.
        grad: The GradOperation used for the backward pass (passed to `check_run`).
        weights: Parameters to differentiate with respect to, or None.
        *args, **kwargs: Inputs of `fn`. When `self.sens_param` is True, the sensitivity is
            taken from `kwargs['sens']` if present, otherwise from the last positional argument.

    Returns:
        tuple: The positional inputs, then the keyword input values, then the sensitivity
        (if one was supplied) as the last element.
    """
    sens = None
    if self.sens_param:
        if 'sens' in kwargs:
            sens = kwargs.pop('sens')
        else:
            # By default the last positional argument is the sensitivity.
            sens = args[-1]
            args = args[:-1]
    run_args = args + tuple(kwargs.values())

    # Run the forward pass (excluding sens) unless the executor reports it has already run.
    if isinstance(fn, (FunctionType, MethodType)):
        if not _pynative_executor.check_run(grad, fn, weights, None, *run_args):
            _pynative_executor.set_grad_flag(True)
            _pynative_executor.new_graph(fn, *args, **kwargs)
            output = fn(*args, **kwargs)
            _pynative_executor.end_graph(fn, output, *args, **kwargs)
    else:
        # Check if fn has run already.
        if not _pynative_executor.check_run(grad, fn, weights, None, *run_args):
            requires_grad = fn.requires_grad
            fn.requires_grad = True
            try:
                fn(*args, **kwargs)
            finally:
                # Restore the caller's flag even if the forward pass raises.
                fn.requires_grad = requires_grad

    # Keep sens as the last element. It was split off as a single argument,
    # so it is appended back as a single element.
    if sens is not None:
        run_args += (sens,)
    return run_args
434
+
435
+
436
class _TaylorOperation(TaylorOperation_):
    """
    Build the function that computes higher-order derivatives of an input function.
    """

    def __init__(self):
        """Create the 'taylorgrad' operation with an empty cache."""
        super().__init__('taylorgrad')
        # Cache of the last built function and the `fn` it was built for.
        self.grad_fn = None
        self.fn = None

    def __call__(self, fn):
        """Return the `jit`-decorated Taylor-derivative function for `fn` (cached per `fn`)."""
        if self.grad_fn is None or not self.fn == fn:
            taylor_grad_ = _TaylorOperation()

            # The derivative is always built through `jit`, i.e. in GRAPH_MODE.
            @jit
            def after_taylor_grad(*args):
                return taylor_grad_(fn)(*args)

            self.grad_fn = after_taylor_grad
            self.fn = fn
        return self.grad_fn
461
+
462
+
463
def _combine_weight(grad_position, weights, out, out_with_ids):
    """Build the `(name, gradient)` part of the result when `return_ids` is True.

    Args:
        grad_position: Requested input positions. When truthy, `out` is indexed as a pair
            whose second element holds the weight gradients.
        weights: A single parameter, or a list/tuple/ParameterTuple of parameters; each
            must expose a `name` attribute.
        out: Raw gradient output.
        out_with_ids (list): Accumulator that may already hold the positional part.

    Returns:
        list: `out_with_ids` with one weight tuple appended when `grad_position` is truthy.
        Otherwise, a new list: `(name, gradient)` pairs for a sequence of weights, or
        `[name, gradient]` for a single parameter.
    """
    if isinstance(weights, (list, ParameterTuple, tuple)):
        entries = [(weight.name, (out[1] if grad_position else out)[index])
                   for index, weight in enumerate(weights)]
    else:
        entries = [weights.name, out[1] if grad_position else out]
    if not grad_position:
        return entries
    out_with_ids.append(tuple(entries))
    return out_with_ids
486
+
487
+
488
def _combine_position(grad_position, weights, out, out_with_ids):
    """Build the `(index, gradient)` part of the result when `return_ids` is True.

    Args:
        grad_position (tuple[int]): Indices of the inputs whose gradients were requested.
        weights: The weights gradients were also taken for, or None.
        out: Raw gradient output. When `weights` is not None it is indexed as a pair whose
            first element holds the positional gradients.
        out_with_ids (list): Accumulator; the positional part is appended to it when
            `weights` is not None.

    Returns:
        list: `out_with_ids` with the positional tuple appended when `weights` is not None.
        Otherwise, a new list of `(index, gradient)` pairs, or `[0, gradient]` when
        `grad_position == (0,)`.
    """
    # Use an explicit None check (as in the branches below) rather than truthiness: an
    # empty weight list is still "weights given", and the truth value of a multi-element
    # Parameter is ambiguous and can raise.
    has_weights = weights is not None
    if grad_position == (0,):
        # A lone gradient w.r.t. input 0 is not indexed per position.
        position_tuple = [0, out[0] if has_weights else out]
    else:
        position_grads = out[0] if has_weights else out
        position_tuple = [(index, position_grads[i]) for i, index in enumerate(grad_position)]
    if has_weights:
        out_with_ids.append(tuple(position_tuple))
        return out_with_ids
    return position_tuple
511
+
512
+
513
def _combine_with_ids(grad_position, weights, out):
    """Assemble the `return_ids` form of a gradient result.

    The positional part comes from `_combine_position` (only when `grad_position` is
    truthy). The weight part comes from `_combine_weight` (only when `weights` is not None).

    Returns:
        tuple: The combined entries.

    Raises:
        ValueError: If neither part produced anything.
    """
    combined = []
    if grad_position:
        combined = _combine_position(grad_position, weights, out, combined)
    if weights is not None:
        combined = _combine_weight(grad_position, weights, out, combined)
    if combined:
        return tuple(combined)
    raise ValueError("output tuple should not be a empty tuple.")
525
+
526
+
527
class _Grad(GradOperation_):
    """
    A higher-order function which is used to generate the gradient function by position for the input function.
    """

    def __init__(self, get_all=False, get_by_list=False, sens_param=False, get_by_position=False, has_aux=False,
                 get_value=False, return_ids=False, merge_forward=False):
        """Initialize _Grad.

        Raises:
            TypeError: If `get_by_position`, `get_by_list`, `sens_param`, `has_aux`, `get_value`
                or `return_ids` is not a bool (`get_all` and `merge_forward` are not checked here).
        """
        if not isinstance(get_by_position, bool):
            raise TypeError(f"For '_Grad', the 'get_by_position' should be bool, "
                            f"but got {type(get_by_position).__name__}")
        if not isinstance(get_by_list, bool):
            raise TypeError(f"For '_Grad', the 'get_by_list' should be bool, "
                            f"but got {type(get_by_list).__name__}")
        if not isinstance(sens_param, bool):
            raise TypeError(f"For '_Grad', the 'sens_param' should be bool, "
                            f"but got {type(sens_param).__name__}")
        if not isinstance(has_aux, bool):
            raise TypeError(f"For '_Grad', the 'has_aux' should be bool, "
                            f"but got {type(has_aux).__name__}")
        if not isinstance(get_value, bool):
            raise TypeError(f"For '_Grad', the 'get_value' should be bool, "
                            f"but got {type(get_value).__name__}")
        if not isinstance(return_ids, bool):
            raise TypeError(f"For '_Grad', the 'return_ids' should be bool, "
                            f"but got {type(return_ids).__name__}")
        self.get_all = get_all
        self.get_by_position = get_by_position
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        self.has_aux = has_aux
        self.get_value = get_value
        self.return_ids = return_ids
        self.merge_forward = merge_forward
        GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, get_by_position, has_aux, get_value,
                                return_ids, merge_forward)
        # Cache of the last gradient function built by __call__ and the key it was built for.
        self.grad_fn = None
        self.fn = None
        # Set to True on the inner instance that __call__ creates in PyNative mode.
        self.pynative_ = False
        self.grad_position = None
        self.weights_id = None

    def __call__(self, fn, weights=None, grad_position=0):
        """Build, or return the cached, gradient function for `fn`.

        The cache key is `fn`, `grad_position` and, in GRAPH_MODE only, the weights id.

        Raises:
            RuntimeError: Outside GRAPH_MODE, when `_pynative_executor.enable_grad()` is False.
        """
        weights_id = ''
        if context.get_context("mode") == context.GRAPH_MODE:
            weights_id = _get_grad_weights_id(weights)
        if self.grad_fn is not None and self.fn == fn and self.grad_position == grad_position and \
                self.weights_id == weights_id:
            return self.grad_fn

        def aux_fn(*args, **kwargs):
            # Differentiate only the first output; the auxiliary outputs are wrapped in StopGradient.
            outputs = fn(*args, **kwargs)
            if not isinstance(outputs, tuple) or len(outputs) < 2:
                raise ValueError("When has_aux is True, origin fn requires more than one outputs.")
            res = (outputs[0],)
            stop_gradient = Primitive("StopGradient")
            for item in outputs[1:]:
                res += (stop_gradient(item),)
            return res

        grad_ = _Grad(self.get_all, self.get_by_list, self.sens_param, self.get_by_position, self.has_aux,
                      self.get_value, self.return_ids, self.merge_forward)
        # If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
        # If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
        # In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
        # In PYNATIVE_MODE calling Grad from functions decorated with 'jit', use the out layer after_grad do
        # grad in GRAPH_MODE.
        if context.get_context("mode") == context.GRAPH_MODE:
            dynamic_shape_inputs = None
            if isinstance(fn, ms.nn.Cell):
                dynamic_shape_inputs = fn.get_inputs()
            if self.get_by_position:
                @jit(input_signature=dynamic_shape_inputs)
                def after_grad(*args):
                    return grad_(fn, weights, grad_position)(*args)
            else:
                if self.get_by_list:
                    @jit(input_signature=dynamic_shape_inputs)
                    def after_grad(*args):
                        return grad_(fn, weights)(*args)
                else:
                    @jit(input_signature=dynamic_shape_inputs)
                    def after_grad(*args):
                        return grad_(fn)(*args)
        elif self.pynative_:
            if not _pynative_executor.enable_grad():
                raise RuntimeError("In no_grad context, you can not calculate gradient")

            @_wrap_func
            def after_grad(*args, **kwargs):
                run_args, res = self._pynative_forward_run(fn, grad_, weights, *args, **kwargs)
                out = _pynative_executor.grad(fn, grad_, weights, grad_position, *run_args)
                out = _grads_divided_by_device_num_if_recomputation(out)
                # `out` may be a single Tensor, whose truth value is ambiguous (or False for a
                # zero scalar), so only skip combining when there is genuinely nothing to combine.
                has_grads = out is not None and not (isinstance(out, (tuple, list)) and len(out) == 0)
                if self.return_ids and has_grads:
                    out = _combine_with_ids(grad_position, weights, out)
                if self.get_value:
                    return res, out
                if self.has_aux:
                    return out, res[1:]
                return out
        else:
            MSContext.get_instance()._set_not_convert_jit(True)
            if not _pynative_executor.enable_grad():
                raise RuntimeError("In no_grad context, you can not calculate gradient")
            grad_.pynative_ = True
            fn_ = aux_fn if self.has_aux else fn
            # after_grad of this branch can't use @jit, just directly call grad_
            if self.get_by_position:
                def after_grad(*args, **kwargs):
                    return grad_(fn_, weights, grad_position)(*args, **kwargs)
            else:
                if self.get_by_list:
                    def after_grad(*args, **kwargs):
                        return grad_(fn_, weights)(*args, **kwargs)
                else:
                    def after_grad(*args, **kwargs):
                        return grad_(fn_)(*args, **kwargs)

        self.grad_fn = after_grad
        self.fn = fn
        self.grad_position = grad_position
        self.weights_id = weights_id
        return self.grad_fn

    def _pynative_forward_run(self, fn, grad, weights, *args, **kwargs):
        """PyNative forward run to build the grad graph.

        Returns:
            tuple: `(run_args, outputs)`.
                - `run_args` holds the positional inputs, then the keyword input values, then
                  the sensitivity (when `sens_param` is True) as the last element.
                - `outputs` is the forward result if `fn` was run here, or if `get_value` or
                  `has_aux` requires it; otherwise `()`.
        """
        sens = None
        if self.sens_param:
            if 'sens' in kwargs:
                sens = kwargs.pop('sens')
            else:
                # By default the last positional argument is the sensitivity.
                sens = args[-1]
                args = args[:-1]
        run_args = args + tuple(kwargs.values())

        # Run the forward pass (excluding sens) unless the executor reports it has already run.
        outputs = ()
        run_forward = False
        if isinstance(fn, (FunctionType, MethodType)):
            if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *run_args):
                _pynative_executor.set_grad_flag(True)
                _pynative_executor.new_graph(fn, *args, **kwargs)
                outputs = fn(*args, **kwargs)
                _pynative_executor.end_graph(fn, outputs, *args, **kwargs)
                run_forward = True
        else:
            # Check if fn has run already.
            if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *run_args):
                requires_grad = fn.requires_grad
                fn.requires_grad = True
                try:
                    outputs = fn(*args, **kwargs)
                finally:
                    # Restore the caller's flag even if the forward pass raises.
                    fn.requires_grad = requires_grad
                run_forward = True
        # Keep sens as the last element. It was split off as a single argument,
        # so it is appended back as a single element.
        if sens is not None:
            run_args += (sens,)

        # Normal run grad
        if run_forward:
            return run_args, outputs

        # fn was not run above; run it now if its outputs are needed.
        if self.get_value or self.has_aux:
            outputs = fn(*args, **kwargs)
        return run_args, outputs
696
+
697
+
698
class _Vmap(VmapOperation_):
    """
    Higher-order function that produces the vectorizing-map version of an input function.
    """

    def __init__(self):
        """Create the 'vmap' operation with an empty cache."""
        super().__init__('vmap')
        # Cache of the last built function and the arguments it was built for.
        self.vmap_fn = None
        self.fn = None
        self.in_axes = None
        self.out_axes = None

    def __call__(self, fn, in_axes=0, out_axes=0):
        """Return the `jit`-decorated vmap of `fn`, reusing the cached one for identical arguments."""
        reusable = (self.vmap_fn is not None and self.fn == fn
                    and self.in_axes == in_axes and self.out_axes == out_axes)
        if not reusable:
            vmap_ = self

            @jit
            def after_vmap(*args, **kwargs):
                return vmap_(fn, in_axes, out_axes)(*args, **kwargs)

            self.vmap_fn, self.fn = after_vmap, fn
            self.in_axes, self.out_axes = in_axes, out_axes
        return self.vmap_fn
726
+
727
+
728
class MultitypeFuncGraph(MultitypeFuncGraph_):
    """
    MultitypeFuncGraph is a class used to generate overloaded functions that dispatch on input types.
    Create a `MultitypeFuncGraph` object with a name, then use `register` with input types as a decorator
    to register the implementation for those types. The object can then be called with inputs of
    different types, and it works with `HyperMap` and `Map`.

    Args:
        name (str): Operator name.
        read_value (bool, optional): If the registered functions do not need to set values on Parameters,
            and all inputs are passed by value, set `read_value` to ``True`` . Default: ``False`` .

    Raises:
        ValueError: If failed to find a matching function for the given arguments.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # `add` is a metagraph object which will add two objects according to
        >>> # input type using ".register" decorator.
        >>> from mindspore import Tensor
        >>> from mindspore import dtype as mstype
        >>> from mindspore import ops
        >>>
        >>> tensor_add = ops.Add()
        >>> add = ops.MultitypeFuncGraph('add')
        >>> @add.register("Number", "Number")
        ... def add_scala(x, y):
        ...     return x + y
        >>> @add.register("Tensor", "Tensor")
        ... def add_tensor(x, y):
        ...     return tensor_add(x, y)
        >>> output = add(1, 2)
        >>> print(output)
        3
        >>> output = add(Tensor([0.1, 0.6, 1.2], dtype=mstype.float32), Tensor([0.1, 0.6, 1.2], dtype=mstype.float32))
        >>> print(output)
        [0.2 1.2 2.4]
    """

    def __init__(self, name, read_value=False):
        """Initialize MultitypeFuncGraph."""
        super().__init__(name)
        # (type signature tuple, function) pairs, in registration order.
        self.entries = []
        # Set by `register_default`; when callable it handles every call.
        self.default_func = None
        if read_value:
            # Declare the variadic positional inputs with RW_READ access.
            self.set_signatures((
                sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),))

    def __call__(self, *args):
        if callable(self.default_func):
            return self.default_func(*args)
        if any(isinstance(arg, np.ndarray) for arg in args):
            raise TypeError("For 'MultitypeFuncGraph', the input can not be numpy.ndarray")
        if len(self.entries) == 1:
            # A single registration is called without type matching.
            return self.entries[0][1](*args)
        types = tuple(map(mstype.get_py_obj_dtype, args))
        for sigs, fn in self.entries:
            matched = len(sigs) == len(types) and all(
                mstype._issubclass_(type_, sig) for sig, type_ in zip(sigs, types))  # pylint: disable=W0212
            if matched:
                return fn(*args)
        raise ValueError(f"For 'MultitypeFuncGraph', cannot find fn match given args. Got (sigs, fn): {self.entries}, "
                         f"and (dtype, args): {types}.")

    def register(self, *type_names):
        """
        Register a function for the given type strings or types.

        Args:
            type_names (Union[str, :class:`mindspore.dtype`]): Inputs type names or types list.

        Return:
            decorator, a decorator that registers the decorated function to run when the
            object is called with inputs of the types described in `type_names`.
        """

        def convert_type(type_input):
            if isinstance(type_input, str):
                return mstype.typing.str_to_type(type_input)
            if not isinstance(type_input, mstype.Type):
                raise TypeError(f"For 'MultitypeFuncGraph', register only support str or {mstype.Type}, but got "
                                f"'type_input': {type_input}.")
            return type_input

        def deco(fn):
            types = tuple(map(convert_type, type_names))
            self.register_fn(type_names, fn)
            self.entries.append((types, fn))
            return fn

        return deco

    def register_default(self):
        """Return a decorator that makes the decorated function handle every call, bypassing type dispatch."""
        def deco(fn):
            self.default_func = fn
            return fn

        return deco

    def set_doc_url(self, doc_url):
        """Delegate to `set_doc_url_` with `doc_url`."""
        self.set_doc_url_(doc_url)

    def set_need_raise(self):
        """Delegate to `set_need_raise_`."""
        self.set_need_raise_()
839
+
840
+
841
class HyperMap(HyperMap_):
    """
    HyperMap applies the given operation to input sequences.

    The operation is applied to every element of the sequence or nested sequence. Unlike
    `mindspore.ops.Map`, `HyperMap` also recurses into nested structures. It supports
    dynamic sequences as input, but this support does not extend to nested dynamic sequences.

    Args:
        ops (Union[MultitypeFuncGraph, None], optional): `ops` is the operation to apply. If `ops` is `None`,
            the operation should be passed as the first input of the instance. Default is None.
        reverse (bool, optional): The optimizer needs to be inverted in some scenarios to improve parallel
            performance; general users can ignore it. `reverse` is the flag to decide whether to apply
            the operation in reverse order. Only supported in graph mode. Default is False.

    Inputs:
        - **args** (Tuple[sequence]) -

          - If `ops` is not `None`, all the inputs should be sequences with the same length.
            And each row of the sequences will be the inputs of the operation.
          - If `ops` is `None`, the first input is the operation, and the others are inputs.

    Note:
        Except for the operation input, the number of inputs should be equal to the number of inputs to `ops`.

    Outputs:
        Sequence or nested sequence, the sequence of output after applying the function.
        e.g. `operation(args[0][i], args[1][i])`, `operation` is the function assigned by `ops`.

    Raises:
        TypeError: If `ops` is neither :class:`mindspore.ops.MultitypeFuncGraph` nor None.
        TypeError: If `args` is not a Tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import Tensor, ops
        >>> from mindspore import dtype as mstype
        >>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
        ...                     (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
        >>> # square all the tensor in the nested list
        >>>
        >>> square = ops.MultitypeFuncGraph('square')
        >>> @square.register("Tensor")
        ... def square_tensor(x):
        ...     return ops.square(x)
        >>>
        >>> common_map = ops.HyperMap()
        >>> output = common_map(square, nest_tensor_list)
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
        >>> square_map = ops.HyperMap(square, False)
        >>> output = square_map(nest_tensor_list)
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
    """

    def __init__(self, ops=None, reverse=False):
        """Initialize HyperMap."""
        self.ops = ops
        # The operation is only forwarded to the base class when one was given.
        base_args = (reverse, ops) if ops else (reverse,)
        HyperMap_.__init__(self, *base_args)

    def __call__(self, *args):
        if self.ops is None:
            func, args_list = args[0], args[1:]
            hypermap = partial(self, func)
        else:
            func, args_list, hypermap = self.ops, args, self
        # Recurse into tuples/lists; anything else is a leaf handed straight to the operation.
        if isinstance(args_list[0], (tuple, list)):
            return tuple(map(hypermap, *args_list))
        return func(*args_list)
922
+
923
+
924
class Map(Map_):
    """
    Map applies the given operation to input sequences.

    The operation is applied to every element of the sequence (no recursion into nested sequences).

    Args:
        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
            the operation should be passed as the first input of the instance. Default: ``None`` .
        reverse (bool): The optimizer needs to be inverted in some scenarios to improve parallel performance;
            general users can ignore it. `Reverse` is the flag to decide whether to apply the operation in
            reverse order. Only supported in graph mode. Default is ``False`` .

    Inputs:
        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences of the
          same length. For each index `i`, the elements `(args[0][i], args[1][i], ...)` are passed together
          as the inputs of the operation.

          If `ops` is `None`, the first input is the operation, and the others are the sequences.

    Outputs:
        Sequence, the sequence of output after applying the ops function. e.g. `ops(args[0][i], args[1][i])`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import dtype as mstype
        >>> from mindspore import Tensor, ops
        >>> from mindspore.ops import MultitypeFuncGraph, Map
        >>> tensor_list = (Tensor(1, mstype.float32), Tensor(2, mstype.float32), Tensor(3, mstype.float32))
        >>> # square all the tensor in the list
        >>>
        >>> square = MultitypeFuncGraph('square')
        >>> @square.register("Tensor")
        ... def square_tensor(x):
        ...     return ops.square(x)
        >>>
        >>> common_map = Map()
        >>> output = common_map(square, tensor_list)
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
        Tensor(shape=[], dtype=Float32, value= 9))
        >>> square_map = Map(square, False)
        >>> output = square_map(tensor_list)
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
        Tensor(shape=[], dtype=Float32, value= 9))
    """

    def __init__(self, ops=None, reverse=False):
        """Initialize Map."""
        self.ops = ops
        # The operation is only forwarded to the base class when one was given.
        base_args = (reverse, ops) if ops else (reverse,)
        Map_.__init__(self, *base_args)

    def __call__(self, *args):
        func, sequences = (args[0], args[1:]) if self.ops is None else (self.ops, args)
        return tuple(map(func, *sequences))
989
+
990
+
991
class _ListAppend(ListAppend_):
    """
    A metafuncgraph class that appends one element to a list.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Initialize _ListAppend."""
        ListAppend_.__init__(self, name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


_append = _ListAppend("append")
1005
+
1006
+
1007
class _ListInsert(ListInsert_):
    """
    Metafuncgraph that inserts one element into a list.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Forward `name` to the base `ListInsert_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


_insert = _ListInsert("insert")
1024
+
1025
+
1026
class _ListPop(ListPop_):
    """
    Metafuncgraph that pops one element from a list.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Forward `name` to the base `ListPop_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


_pop = _ListPop("pop")
1043
+
1044
+
1045
class _ListClear(ListClear_):
    """
    Metafuncgraph that clears a list.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Forward `name` to the base `ListClear_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


_list_clear = _ListClear("clear")
1062
+
1063
+
1064
class _ListReverse(ListReverse_):
    """
    Metafuncgraph that reverses a list.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Forward `name` to the base `ListReverse_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


_reverse = _ListReverse("reverse")
1081
+
1082
+
1083
class _ListExtend(ListExtend_):
    """
    Metafuncgraph that appends another list to the end of a list.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Forward `name` to the base `ListExtend_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


_extend = _ListExtend("extend")
1100
+
1101
+
1102
class _DictSetItem(DictSetItem_):
    """
    Metafuncgraph that sets an item of a dict.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Initialize _DictSetItem by forwarding `name` to the base `DictSetItem_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


_dict_setitem = _DictSetItem("setitem")
1119
+
1120
+
1121
class _DictClear(DictClear_):
    """
    Metafuncgraph that clears a dict.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Forward `name` to the base `DictClear_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


_dict_clear = _DictClear("clear")
1138
+
1139
+
1140
class _DictHasKey(DictHasKey_):
    """
    Metafuncgraph that checks whether a key is in a dict.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Forward `name` to the base `DictHasKey_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


_haskey = _DictHasKey("has_key")
1157
+
1158
+
1159
class _DictUpdate(DictUpdate_):
    """
    Metafuncgraph that merges another dict into a dict.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Forward `name` to the base `DictUpdate_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


_update = _DictUpdate("update")
1176
+
1177
+
1178
class _DictFromKeys(DictFromKeys_):
    """
    Metafuncgraph that creates a new dict from the given sequence and value.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Forward `name` to the base `DictFromKeys_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


_fromkeys = _DictFromKeys("fromkeys")
1195
+
1196
+
1197
class _Tail(Tail_):
    """
    Metafuncgraph that generates the tail elements of a tuple.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Forward `name` to the base `Tail_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


tail = _Tail('tail')
1214
+
1215
+
1216
class _ZipOperation(ZipOperation_):
    """Metafuncgraph that generates a tuple of zip iterations for its inputs."""

    def __init__(self, name):
        """Forward `name` to the base `ZipOperation_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


zip_operation = _ZipOperation('zip_operation')
"""`zip_operation` will generate a tuple of zip iterations of inputs."""
1229
+
1230
+
1231
class _StarredGetItem(StarredGetItem_):
    """Metafuncgraph that generates a list of starred get_item for its inputs."""

    def __init__(self, name):
        """Forward `name` to the base `StarredGetItem_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


starred_get_item = _StarredGetItem('starred_get_item')
"""`starred_get_item` will generate a list of starred get_item for inputs."""
1244
+
1245
+
1246
class _StarredUnpack(StarredUnpack_):
    """Metafuncgraph that generates a tuple of starred unpack for its inputs."""

    def __init__(self, name):
        """Forward `name` to the base `StarredUnpack_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


starred_unpack = _StarredUnpack('starred_unpack')
"""`starred_unpack` will generate a tuple of starred unpack for inputs."""
1259
+
1260
+
1261
class _StarredUnpackMerge(StarredUnpackMerge_):
    """Metafuncgraph that generates a tuple of starred unpack merge for its inputs."""

    def __init__(self, name):
        """Forward `name` to the base `StarredUnpackMerge_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


starred_unpack_merge = _StarredUnpackMerge('starred_unpack_merge')
"""`starred_unpack_merge` will generate a tuple of starred unpack merge for inputs."""
1274
+
1275
+
1276
class _IterConverter(IterConverter_):
    """Metafuncgraph that converts its input to an iterable object."""

    def __init__(self, name):
        """Forward `name` to the base `IterConverter_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


iter_converter = _IterConverter('iter_converter')
"""`iter_converter` will convert input to iterable object"""
1289
+
1290
+
1291
class _HasNext(HasNext_):
    """Metafuncgraph that checks whether its input has a next value."""

    def __init__(self, name):
        """Forward `name` to the base `HasNext_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


ms_hasnext = _HasNext('has_next')
"""`ms_hasnext` will check whether the input has next value"""
1304
+
1305
+
1306
class _Next(Next_):
    """Metafuncgraph that gets the next element and the rest of the elements of its input."""

    def __init__(self, name):
        """Forward `name` to the base `Next_` constructor."""
        super().__init__(name)

    def __call__(self, *args):
        """Does nothing when invoked from Python; returns None."""


ms_next = _Next('next')
"""`ms_next` will get next element and rest elements for input"""