mindspore 2.4.0__cp311-cp311-macosx_10_15_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1387):
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1348 @@
1
+ # Copyright 2020-2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ===========================================================================
15
+ """Cost model splitter"""
16
+ from functools import reduce as prod_reduce
17
+ from functools import partial
18
+ from .model import PrimLib, Graph, Tensor, Operator
19
+
20
+
21
def tensor_size(tensor):
    """Return the number of elements of `tensor`, i.e. the product of `tensor.shape`.

    An empty shape yields 1.
    """
    return prod_reduce(lambda total, dim: total * dim, tensor.shape, 1)
27
+
28
+
29
def reduce_nums(ops):
    """Count the ops in `ops` whose primitive name starts with 'Reduce'."""
    return sum(1 for candidate in ops if candidate.prim.startswith('Reduce'))
36
+
37
+
38
def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size):
    """Check whether area `a` can be stitched onto area `dom`.

    The answer is True only when all of the following hold:
      * `a.pattern` is at most REDUCE, relation `r` is at most BROADCAST and
        `dom.check_acyclic(a)` succeeds;
      * `a` contains fewer than two Reduce* ops;
      * every tensor passed from `dom` to `a`, and every final output of `a`,
        has the same non-empty stitch axis (see `_leading_axes`);
      * at least one tensor passed from `dom` to `a` holds at least
        `stitch_buffer_size` elements.
    """

    def _leading_axes(shape, threshold):
        """Shortest leading prefix of `shape` whose product reaches `threshold`, or [] if none does."""
        prefix, volume = [], 1
        for dim in shape:
            volume = volume * dim
            prefix.append(dim)
            if volume >= threshold:
                return prefix
        return []

    def _share_stitch_axis(stitch_tensors, final_outs, threshold):
        """True when every tensor has a non-empty stitch axis equal to that of the first one."""
        candidates = [*stitch_tensors, *final_outs]
        reference = _leading_axes(candidates[0].shape, threshold)
        for tensor in candidates:
            axes = _leading_axes(tensor.shape, threshold)
            if not axes or axes != reference:
                return False
        return True

    if not (a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a)):
        return False
    if reduce_nums(a.ops) >= 2:
        return False

    produced_by_dom = {op.output for op in dom.ops}
    consumed_by_a = {tensor for op in a.ops for tensor in op.inputs}
    produced_by_a = {op.output for op in a.ops}
    # outputs of `a` that no op inside `a` reads again
    final_outs = [tensor for tensor in produced_by_a if tensor not in consumed_by_a]
    # tensors produced in `dom` and read by `a`
    stitch_tensors = [tensor for tensor in produced_by_dom if tensor in consumed_by_a]

    if not _share_stitch_axis(stitch_tensors, final_outs, stitch_axis_size):
        return False
    for tensor in stitch_tensors:
        if tensor_size(tensor) >= stitch_buffer_size:
            return True
    return False
77
+
78
+
79
class CommonPattern:
    """Fuse strategies shared by all devices.

    Every strategy takes the dominant area `dom` and returns a pair
    `(areas, forward_fuse)`: the list of areas selected for fusion and a flag
    that is True when the selected areas come from `dom.in_relations`.
    """

    @staticmethod
    def reshape(dom):
        """Pick, for a RESHAPE dom, the neighbour with the smallest pattern (at most BROADCAST)."""
        if dom.pattern != PrimLib.RESHAPE:
            return [], False
        best, from_input = None, False
        for consumer in dom.out_relations:
            if consumer.pattern <= PrimLib.BROADCAST and dom.check_acyclic(consumer) and \
                    (best is None or consumer.pattern < best.pattern):
                best = consumer
        for producer in dom.in_relations:
            if producer.pattern <= PrimLib.BROADCAST and producer.check_acyclic(dom) and \
                    (best is None or producer.pattern < best.pattern):
                best, from_input = producer, True
        if not best:
            return [], False
        return [best], from_input

    @staticmethod
    def isolate_reshape(dom):
        """Attach a single-op RESHAPE dom to the first composite neighbour that keeps the graph acyclic."""
        if dom.pattern != PrimLib.RESHAPE or len(dom.ops) != 1:
            return [], False
        composite = GraphSplitByPattern.Area.MODE_COMPOSITE
        for consumer in dom.out_relations:
            if consumer.mode == composite and dom.check_acyclic(consumer):
                return [consumer], False
        for producer in dom.in_relations:
            if producer.mode == composite and producer.pattern <= PrimLib.BROADCAST and \
                    producer.check_acyclic(dom):
                return [producer], True
        return [], False

    @staticmethod
    def elemwise_depth(dom):
        """Fuse an ELEMWISE dom with its only input area when that area feeds nothing else."""
        if dom.pattern != PrimLib.ELEMWISE or len(dom.in_relations) != 1:
            return [], False
        producer, rel = next(iter(dom.in_relations.items()))
        fusable = producer.pattern <= PrimLib.ELEMWISE and len(producer.out_relations) == 1 and \
            rel <= PrimLib.ELEMWISE and \
            tensor_size(producer.dom_op().output) == tensor_size(dom.dom_op().output)
        return ([producer], True) if fusable else ([], False)

    @staticmethod
    def elemwise_width(dom):
        """Fuse an ELEMWISE dom with every same-sized elemwise input area."""
        if dom.pattern != PrimLib.ELEMWISE:
            return [], False
        fused = [producer for producer, rel in dom.in_relations.items()
                 if producer.pattern <= PrimLib.ELEMWISE and rel <= PrimLib.ELEMWISE
                 and producer.check_acyclic(dom)
                 and tensor_size(producer.dom_op().output) == tensor_size(dom.dom_op().output)]
        return fused, True

    @staticmethod
    def broadcast_depth(dom):
        """Fuse an ELEMWISE/BROADCAST dom with its only input area when that area feeds nothing else."""
        if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
            return [], False
        producer, rel = next(iter(dom.in_relations.items()))
        fusable = producer.pattern <= PrimLib.BROADCAST and len(producer.out_relations) == 1 and \
            rel <= PrimLib.ELEMWISE and \
            tensor_size(producer.dom_op().output) == tensor_size(dom.dom_op().output)
        return ([producer], True) if fusable else ([], False)

    @staticmethod
    def broadcast_width(dom):
        """Fuse an ELEMWISE/BROADCAST dom with every same-sized input area of pattern at most BROADCAST."""
        if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
            return [], False
        fused = [producer for producer, rel in dom.in_relations.items()
                 if producer.pattern <= PrimLib.BROADCAST and rel <= PrimLib.ELEMWISE
                 and producer.check_acyclic(dom)
                 and tensor_size(producer.dom_op().output) == tensor_size(dom.dom_op().output)]
        return fused, True

    @staticmethod
    def assign(dom):
        """Fuse a dom made of a single "Assign" op with all of its input areas."""
        if len(dom.ops) != 1 or dom.dom_op().prim != "Assign":
            return [], False
        return list(dom.in_relations), True
169
+
170
+
171
class ReshapeElimChecker:
    """Fuse checker attached to an area created from a reshape op.

    On construction it collects two sets of "excluded" ops, one for each
    direction (`exc_fwd`, `exc_bwd`). `check` reports, per direction, whether
    a group of ops avoids the excluded set; `commit` permanently disables a
    direction (sets its set to None) once a fusion was accepted without it.
    """

    def __init__(self, reshape):
        def _get_remap_axis(in_shape, out_shape):
            """Return the input axes and output axes (negative indices) that the reshape remaps.

            Both shapes are scanned from the last axis towards the first while
            comparing running products; an input axis that lines up with an
            equal output axis at the same running product is not recorded.
            """
            remap_in, remap_out = [], []
            in_total, out_total, out_pos = 1, out_shape[-1], -1
            in_stop, out_stop = -len(in_shape) - 1, -len(out_shape) - 1
            for in_pos in range(-1, in_stop, -1):
                in_total *= in_shape[in_pos]
                while out_total < in_total:
                    remap_out.append(out_pos)
                    out_pos -= 1
                    out_total *= out_shape[out_pos]
                if out_total == in_total and out_pos > out_stop and out_shape[out_pos] == in_shape[in_pos]:
                    out_pos -= 1
                    if out_pos > out_stop:
                        out_total *= out_shape[out_pos]
                else:
                    remap_in.append(in_pos)
            if out_pos > out_stop:
                # every output axis not reached by the scan is remapped as well
                remap_out.extend(range(out_pos, out_stop, -1))
            return remap_in, remap_out

        axes_in, axes_out = _get_remap_axis(reshape.inputs[0].shape, reshape.output.shape)
        self.exc_fwd = self._collect_exc_ops(reshape, axes_in, True)
        self.exc_bwd = self._collect_exc_ops(reshape, axes_out, False)

    @staticmethod
    def _collect_exc_ops(reshape, remap_axis, is_fwd):
        """Collect the ops excluded from fusing with `reshape`.

        Starting from the reshape's inputs (`is_fwd`) or from the users of its
        output, ops are visited through both their inputs and the users of
        their output. An op passing `_remap_check` (or another reshape, after
        translating the remapped axes) is traversed further; any other op is
        added to the returned set.
        """
        excluded, pending, seen = set(), [], {reshape}

        def _propagate(remap, src, des):
            """Translate the remapped axes of shape `src` into axes of shape `des`."""
            translated = []
            src_total, des_total, des_pos = 1, 1, 0
            for src_pos in range(-1, -len(src) - 1, -1):
                src_total *= src[src_pos]
                if src_pos in remap:
                    while des_total < src_total:
                        des_pos -= 1
                        des_total *= des[des_pos]
                        translated.append(des_pos)
                else:
                    while des_total < src_total:
                        grown = des_total * des[des_pos - 1]
                        if grown > src_total:
                            break
                        des_pos, des_total = des_pos - 1, grown
            return translated

        def _remap_check(op, remap, iter_type):
            """True if `op` is elemwise/broadcast and no input differs from the output on a remapped axis."""
            if iter_type not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
                return False
            for tensor in op.inputs:
                for axis in remap:
                    # axes beyond the rank of the input are ignored
                    if -axis <= len(tensor.shape) and tensor.shape[axis] != op.output.shape[axis]:
                        return False
            return True

        def push_stack(op, remap):
            """Schedule `op` for traversal and mark it visited."""
            pending.append((op, remap))
            seen.add(op)

        def _visit_fwd(op, remap):
            """Inspect the producers of the inputs of `op`."""
            for tensor in op.inputs:
                producer = tensor.op
                if producer is None:
                    _visit_bwd(tensor, remap)
                elif tensor_size(tensor) > 1 and producer not in seen:  # all broadcast
                    kind = PrimLib.iter_type(producer)
                    if kind == PrimLib.RESHAPE:
                        push_stack(producer, _propagate(remap, tensor.shape, producer.inputs[0].shape))
                    elif _remap_check(producer, remap, kind):
                        push_stack(producer, remap)
                    else:
                        excluded.add(producer)

        def _visit_bwd(t, remap):
            """Inspect the ops that consume tensor `t`."""
            for user in t.to_ops:
                if user in seen:
                    continue
                kind = PrimLib.iter_type(user)
                if kind == PrimLib.REDUCE and tensor_size(user.output) == 1:  # all reduce
                    continue
                if kind == PrimLib.RESHAPE:
                    push_stack(user, _propagate(remap, t.shape, user.output.shape))
                elif _remap_check(user, remap, kind):
                    push_stack(user, remap)
                else:
                    excluded.add(user)

        if is_fwd:
            _visit_fwd(reshape, remap_axis)
        else:
            _visit_bwd(reshape.output, remap_axis)
        while pending:
            current, current_remap = pending.pop()
            _visit_bwd(current.output, current_remap)
            _visit_fwd(current, current_remap)
        return excluded

    def check(self, ops, is_fwd):
        """Return `[fwd_ok, bwd_ok]` if at least one direction stays valid for `ops`, else False.

        Only the direction selected by `is_fwd` is tested against `ops`; the
        other direction merely reports whether it is still enabled.
        """
        fwd_alive = self.exc_fwd is not None
        bwd_alive = self.exc_bwd is not None
        if is_fwd:
            fwd_res = fwd_alive and all(op not in self.exc_fwd for op in ops)
            bwd_res = bwd_alive
        else:
            fwd_res = fwd_alive
            bwd_res = bwd_alive and all(op not in self.exc_bwd for op in ops)
        if fwd_res or bwd_res:
            return [fwd_res, bwd_res]
        return False

    def commit(self, res):
        """Disable every direction that `res` (a result of `check`) reported as invalid."""
        fwd_ok, bwd_ok = res[0], res[1]
        if not fwd_ok:
            self.exc_fwd = None
        if not bwd_ok:
            self.exc_bwd = None
289
+
290
+
291
class ReduceOutFuseChecker:
    """Fuse checker attached to an area created from a reduce op.

    It records in `output_excluded` the downstream ops that must not be fused
    behind the reduce.
    """

    def __init__(self, red_op):
        self.output_excluded = set()
        pending = [red_op]
        while pending:
            producer = pending.pop()
            for consumer in producer.output.to_ops:
                pos = consumer.inputs.index(producer.output)
                # keep following consumers that are at most elemwise and size-preserving
                follows = PrimLib.iter_type(consumer) <= PrimLib.ELEMWISE and \
                    tensor_size(consumer.inputs[pos]) == tensor_size(consumer.output)
                if follows:
                    pending.append(consumer)
                else:
                    self.output_excluded.add(consumer)

    def check(self, ops, is_fwd):
        """Return False only when `is_fwd` is falsy and `ops` contains an excluded op."""
        if is_fwd:
            return True
        return not any(banned in ops for banned in self.output_excluded)

    def commit(self, res):
        """Nothing to commit; `res` is ignored and the excluded set is returned."""
        del res
        return self.output_excluded  # I'm not static
319
+
320
+
321
+ class GraphSplitByPattern:
322
+ """Graph splitter"""
323
+
324
class ReachTable:
    """Boolean reachability matrix over node ids `0 .. size-1`.

    `map[x][y]` is True when `y` is recorded as reachable from `x`; every node
    reaches itself. `alive` holds the ids that have not been fused away.
    """

    def __init__(self, size):
        self.map = [[row == col for col in range(size)] for row in range(size)]
        self.alive = set(range(size))

    def reachable(self, x, y):
        """Tell whether `y` is recorded as reachable from `x`."""
        return self.map[x][y]

    def sync(self, x, y):
        """Make every alive node reachable from `y` reachable from `x` too."""
        row_y = self.map[y]
        for node in self.alive:
            self._link(row_y[node], x, node)

    def _link(self, cond, f, t):
        """Record `t` as reachable from `f` when `cond` is true."""
        if cond:
            self.map[f][t] = True

    def fuse(self, x, y):
        """Merge node `y` into node `x`, then drop `y` from the alive set."""
        table = self.map
        for node in self.alive:
            # `node` follows y but not yet x: everything that reaches x now reaches `node`
            if table[y][node] and not table[x][node]:
                for pred in self.alive:
                    self._link(table[pred][x], pred, node)
            # `node` precedes y but not yet x: `node` now reaches everything x reaches
            if table[node][y] and not table[node][x]:
                for succ in self.alive:
                    self._link(table[x][succ], node, succ)
        self.alive.remove(y)
360
+
361
+ class Area:
362
+ """Area"""
363
+ MODE_BASIC = 1
364
+ MODE_COMPOSITE = 2
365
+
366
+ class StitchInfo:
367
+ """StitchInfo"""
368
+
369
+ def __init__(self):
370
+ self.stitch_ops = set()
371
+ self.stitch_atomic_ops = set()
372
+
373
+ def has_stitch_op(self):
374
+ """check stitch_op exists"""
375
+ return self.stitch_ops or self.stitch_atomic_ops
376
+
377
+ def __init__(self, init_op, is_output, unique_id, reach_tab):
378
+ self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN
379
+ self.ops = [] if init_op is None else [init_op]
380
+ self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
381
+ self.out_relations = dict() # {area1: relation1, area2: relation2, ...}
382
+ self.mode = None
383
+ self.stitch_info = self.StitchInfo()
384
+ self.recompute_ops = []
385
+ self.is_output = is_output
386
+ self.output_excluded = set()
387
+ self.unique_id = unique_id
388
+ self.reach_tab = reach_tab
389
+ self.checkers = []
390
+ if self.pattern == PrimLib.RESHAPE and init_op.inputs: # reshape's input may be empty (const value)
391
+ self.checkers.append(ReshapeElimChecker(init_op))
392
+ elif self.pattern == PrimLib.REDUCE:
393
+ self.checkers.append(ReduceOutFuseChecker(init_op))
394
+
395
+ def __str__(self):
396
+ return '<' + '-'.join((op.output.name for op in self.ops)) + '>'
397
+
398
+ def __repr__(self):
399
+ return str(self)
400
+
401
+ @staticmethod
402
+ def get_relation(op, i):
403
+ """Get op relation"""
404
+ relation = PrimLib.UNKNOWN
405
+ _, elem_relation = PrimLib.input_relation(op, i)
406
+ for r in elem_relation:
407
+ if r is None:
408
+ relation = max(relation, PrimLib.BROADCAST)
409
+ elif r > relation:
410
+ relation = r
411
+ return relation
412
+
413
+ def link_input(self, area_map):
414
+ """Link inputs"""
415
+ for i, t in enumerate(self.ops[0].inputs):
416
+ if t.op is not None:
417
+ area, relation = area_map[t.op], self.get_relation(self.ops[0], i)
418
+ self.in_relations[area] = relation
419
+
420
+ def link_output(self):
421
+ """Link outputs"""
422
+ for input_area, r in self.in_relations.items():
423
+ input_area.out_relations[self] = r
424
+ for out, _ in self.out_relations.items():
425
+ self.reach_tab.sync(self.unique_id, out.unique_id)
426
+
427
+ def update_stitch_info(self, stitch_info):
428
+ """Update stitch info"""
429
+ if stitch_info.stitch_ops:
430
+ self.stitch_info.stitch_ops.update(stitch_info.stitch_ops)
431
+ if stitch_info.stitch_atomic_ops:
432
+ self.stitch_info.stitch_atomic_ops.update(stitch_info.stitch_atomic_ops)
433
+
434
+ def fuse_confirm(self, area):
435
+ """confirm if area can be fused"""
436
+
437
+ def _check(a, b, res, fwd):
438
+ for checker in a.checkers:
439
+ r = checker.check(b.ops, fwd)
440
+ if not r:
441
+ return False
442
+ res.append(r)
443
+ return True
444
+
445
+ def _commit(a, res):
446
+ for i, checker in enumerate(a.checkers):
447
+ checker.commit(res[i])
448
+
449
+ res1, res2 = [], []
450
+ if not _check(self, area, res1, True) or not _check(area, self, res2, False):
451
+ return False
452
+ _commit(self, res1)
453
+ _commit(area, res2)
454
+ return True
455
+
456
+ def fuse_prepare(self, dom):
457
+ """do some prepare before fused to dom"""
458
+ del dom
459
+ return self.unique_id # I'm not static method
460
+
461
+ def fuse_done(self, dom):
462
+ """do some thing after fused to dom"""
463
+ dom.reach_tab.fuse(dom.unique_id, self.unique_id)
464
+
465
def fuse(self, area):
    """Fuse `area` to `self`

    Merges the ops, pattern, relations, stitch info, recompute ops and
    checkers of `area` into `self`, and rewrites the relations of the
    neighbours of `area` so that they point at `self` instead.

    Args:
        area: The area being absorbed. Its `fuse_prepare` hook is called
            before and its `fuse_done` hook after the merge.
    """

    def _update_relation(relations, a, r):
        # Keep the larger relation value when `a` is already present.
        relations[a] = max(r, relations[a]) if a in relations else r

    def _update_pattern():
        # The fused pattern is the maximum of both areas' patterns ...
        if area.pattern > self.pattern:
            self.pattern = area.pattern
        # ... and of the relation from `area` to `self`, if `area` is an input of `self`.
        if area in self.in_relations and self.in_relations.get(area) > self.pattern:
            self.pattern = self.in_relations.get(area)

    def _fuse_relation(self_relations, new_relations):
        # Merge the relations of `area` (skipping links back to `self`),
        # then drop the now-internal link to `area`.
        for a, r in new_relations.items():
            if a != self:
                _update_relation(self_relations, a, r)
        if area in self_relations:
            self_relations.pop(area)

    def _redirect_relation(rels):
        """Replace `area` with `self` in relations"""
        if area in rels:
            r = rels.pop(area)
            _update_relation(rels, self, r)

    area.fuse_prepare(self)
    # The ops of the area with the larger pattern go first; ops[0] is what dom_op() returns.
    if self.pattern >= area.pattern:
        self.ops.extend(area.ops)
    else:
        self.ops = area.ops + self.ops
    # Must run before _fuse_relation, which removes `area` from self.in_relations.
    _update_pattern()
    _fuse_relation(self.in_relations, area.in_relations)
    _fuse_relation(self.out_relations, area.out_relations)
    # Neighbours of `area` now link to `self`.
    for a, _ in area.in_relations.items():
        _redirect_relation(a.out_relations)
    for a, _ in area.out_relations.items():
        _redirect_relation(a.in_relations)
    if self.pattern > PrimLib.RESHAPE:
        self.mode = self.MODE_COMPOSITE
    if area.is_output and not self.is_output:
        self.is_output = True
    self.update_stitch_info(area.stitch_info)
    self.recompute_ops.extend(area.recompute_ops)
    self.checkers.extend(area.checkers)
    area.fuse_done(self)
510
+
511
def check_acyclic(self, to):
    """Check that fusing with ``to`` would not create a circle.

    Returns:
        bool. False if any output area other than ``to`` can reach ``to``
        according to the reach table, True otherwise.
    """
    target_id = to.unique_id
    has_detour = any(
        out != to and self.reach_tab.reachable(out.unique_id, target_id)
        for out in self.out_relations
    )
    return not has_detour
517
+
518
def dom_op(self):
    """Return the dominant op, i.e. the first entry of ``self.ops``."""
    dominant = self.ops[0]
    return dominant
521
+
522
class RecomputeArea(Area):
    """Reusable scratch area holding a copied ("recomputed") chain of ops.

    `reset` wires the area between a producer area and one of its users,
    `fuse_prepare`/`fuse_done` are the hooks called while the user tries to
    fuse it, and `clear` detaches it again so it can be reused.
    """

    def __init__(self, unique_id, reach_tab):
        super().__init__(None, False, unique_id, reach_tab)
        self.recom_pre = None  # area feeding the recomputed region, or None
        self.recom_user = None  # area that consumes the recomputed result
        self.recom_dom = None  # area the region was copied from
        self.dom_user_r = PrimLib.UNKNOWN  # saved relation between recom_dom and recom_user
        self.ori_op_map = {}  # copied op -> original op, for the current attempt
        self.recom_map = {}  # copied op -> original op, accumulated over all attempts
        self.fuse_success = False

    def fuse_prepare(self, dom):
        """copy recompute_ops in area to ops, self is area's user

        Every op in `self.recompute_ops` gets a copy with a fresh output
        tensor; the consumers inside `dom` are rewired to the copy of the
        last op's output, and `self.ops` is replaced by the copies with the
        copy of the dominant op first.
        """
        tail_tensor = self.recompute_ops[-1].output
        # copy tensors, all copied are Tensor.PARA_NONE
        tensor_map = {}
        if self.recompute_ops[0].inputs:
            # The first op's input is mapped to itself and therefore shared, not copied.
            first_input = self.recompute_ops[0].inputs[0]
            tensor_map[first_input] = first_input
        for op in self.recompute_ops:
            orig_tensor = op.output
            cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format)
            tensor_map[orig_tensor] = cp_tensor
        # copy ops
        cp_ops = []
        for op in self.recompute_ops:
            inputs = [tensor_map.get(op.inputs[0])] if op.inputs else []
            cp_op = Operator(op.prim, inputs, tensor_map.get(op.output), op.attrs)
            cp_op.all_inputs = cp_op.inputs
            cp_ops.append(cp_op)
            self.ori_op_map[cp_op] = op
        # connect copied ops
        cp_tail = tensor_map.get(tail_tensor)
        for op in dom.ops:
            if tail_tensor in op.inputs:
                # Replace in place so the operand order of the consumer is
                # preserved (remove + append would move the tensor to the end).
                op.inputs[op.inputs.index(tail_tensor)] = cp_tail
                tail_tensor.to_ops.remove(op)
                cp_tail.to_ops.append(op)
        # fill cp_ops in self.recompute_area
        cp_dom_op = None
        for cp, ori in self.ori_op_map.items():
            if ori == self.dom_op():
                cp_dom_op = cp
        self.ops.clear()
        self.ops.append(cp_dom_op)
        self.ops.extend((op for op in cp_ops if op != cp_dom_op))

    def fuse_done(self, dom):
        """do some thing after fused to dom"""
        del dom
        self.fuse_success = True

    def reset(self, dom_area, ops, user_area, pre_area):
        """set the recompute area and connect with other areas

        Args:
            dom_area: Area the region `ops` belongs to.
            ops (list): Ops to recompute.
            user_area: Area consuming the region's result.
            pre_area: Area feeding the region, or None.
        """
        self.recompute_ops.extend(ops)
        # recom_area: set dom_op and correct ops length
        patterns = [PrimLib.iter_type(op) for op in ops]
        self.pattern = max(patterns)
        # The first op with the maximal pattern is the dominant op; it is
        # repeated so that len(self.ops) equals the size of the region.
        self.ops = [ops[patterns.index(self.pattern)]] * len(ops)
        # disconnect dom_area and user_area
        self.dom_user_r = dom_area.out_relations[user_area]
        dom_area.out_relations.pop(user_area)
        user_area.in_relations.pop(dom_area)
        # connect recom_area and user_area
        user_area.in_relations[self] = self.dom_user_r
        self.out_relations[user_area] = self.dom_user_r
        # connect recom_pre and recom_area
        self.recom_pre = pre_area
        if self.recom_pre is not None:
            self.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre]
            self.recom_pre.out_relations[self] = dom_area.in_relations[self.recom_pre]
        # set related areas
        self.recom_user = user_area
        self.recom_dom = dom_area
        self.fuse_success = False

    def clear(self):
        """disconnect recom_area from other areas, and clear recom_area"""
        self.out_relations.clear()
        self.in_relations.clear()
        if not self.fuse_success:
            # The attempt failed: restore the direct link between dom and user.
            self.recom_user.in_relations.pop(self)
            self.recom_user.in_relations[self.recom_dom] = self.dom_user_r
            self.recom_dom.out_relations[self.recom_user] = self.dom_user_r
            # Same None test as in reset().
            if self.recom_pre is not None:
                self.recom_pre.out_relations.pop(self)
        self.ops.clear()
        self.recompute_ops.clear()
        # Remember the copies of this attempt before forgetting them.
        self.recom_map.update(self.ori_op_map)
        self.ori_op_map.clear()
616
+
617
def __init__(self, graph, flags):
    """Create one area per op of ``graph`` and link the areas together.

    Args:
        graph: Graph to split; must provide ``ops`` and ``deduce_parameters()``.
        flags (dict): Fusion switches and limits.
    """
    self.graph = graph
    self.areas = []
    self.flags = flags
    read_flag = self.flags.get
    self.enable_recompute = read_flag("enable_recompute_fusion", False)
    self.enable_stitch_fusion = read_flag("enable_stitch_fusion", False)
    self.enable_horizontal_fusion = read_flag("enable_horizontal_fusion", False)
    self.reduce_fuse_depth = read_flag("reduce_fuse_depth", -1)
    # One extra reach-table slot is reserved for the recompute area.
    table_size = len(graph.ops) + 1 if self.enable_recompute else len(graph.ops)
    self.reach_tab = self.ReachTable(table_size)
    self.area_map = {}
    _, outputs = graph.deduce_parameters()
    for area_id, op in enumerate(graph.ops):
        new_area = self.Area(op, op.output in outputs, area_id, self.reach_tab)
        self.set_default_mode(new_area)
        self.areas.append(new_area)
        self.set_area_map([op], new_area)
    for area in self.areas:
        area.link_input(self.area_map)
    # Outputs are linked from the last area back to the first.
    for area in reversed(self.areas):
        area.link_output()
    if self.enable_recompute:
        # The recompute area takes the id following the last regular area.
        self.recom_area = self.RecomputeArea(len(self.areas), self.reach_tab)
642
+
643
def set_area_map(self, ops, area):
    """Point every op in ``ops`` at ``area`` in ``self.area_map``."""
    self.area_map.update((op, area) for op in ops)
647
+
648
def set_default_mode(self, area):
    """Assign ``area.mode`` from the default mode of its first op."""
    first_op = area.ops[0]
    area.mode = self.get_default_mode(first_op)
651
+
652
+ @staticmethod
653
+ def limit_area_size(dominant, fuse_areas, limit_size):
654
+ """Remove some areas if the size is too large"""
655
+ area_sizes = map(lambda area: len(area.ops), fuse_areas)
656
+ dom_size = len(dominant.ops)
657
+ if dom_size + prod_reduce(lambda x, y: x + y, area_sizes) <= limit_size:
658
+ return fuse_areas
659
+ # fuse the smaller area in priority
660
+ fuse_areas.sort(key=lambda area: len(area.ops))
661
+ new_fuse_areas = []
662
+ for area in fuse_areas:
663
+ if dom_size + len(area.ops) > limit_size:
664
+ break
665
+ dom_size += len(area.ops)
666
+ new_fuse_areas.append(area)
667
+ return new_fuse_areas
668
+
669
def fuse(self, selector, is_stitch=False):
    """Fuse areas repeatedly until ``selector`` proposes nothing fusible.

    Args:
        selector (callable): Maps a dominant area to ``(areas, is_forward)``
            or to a falsy value when it has no proposal.
        is_stitch (bool): When True, the size limit and the fuse
            confirmation are both skipped.

    Returns:
        bool. True if at least one fusion happened.
    """

    def _absorb(dominant, candidates):
        """Fuse every confirmed candidate into ``dominant``."""
        merged = False
        for candidate in candidates:
            if not (is_stitch or dominant.fuse_confirm(candidate)):
                continue
            dominant.fuse(candidate)
            self.set_area_map(candidate.ops, dominant)
            self.areas.remove(candidate)
            merged = True
        return merged

    def _hand_over(dominant, candidates):
        """Fuse ``dominant`` into the candidates, chaining from one to the next."""
        merged = False
        carried = dominant
        for candidate in candidates:
            if not (is_stitch or candidate.fuse_confirm(carried)):
                continue
            candidate.fuse(carried)
            self.set_area_map(carried.ops, candidate)
            self.areas.remove(carried)
            carried = candidate
            merged = True
        return merged

    def _single_pass():
        """Run the selector over the areas; stop at the first dominant that fused."""
        for dominant in self.areas:
            proposal = selector(dominant)
            if not proposal or not proposal[0]:
                continue
            candidates, is_forward = proposal
            if not is_stitch:
                candidates = self.limit_area_size(dominant, candidates, self.flags['composite_op_limit_size'])
            if not candidates:
                continue
            merge = _absorb if is_forward else _hand_over
            if merge(dominant, candidates):
                # self.areas was modified, so the scan has to restart.
                return True
        return False

    any_change = False
    while _single_pass():
        any_change = True
    return any_change
708
+
709
def hfuse(self, selector):
    """Fuse sibling areas that share a producer area or a graph input.

    Args:
        selector (callable): ``selector(dom, other)`` decides whether the
            two siblings may be fused.

    Returns:
        bool. True if at least one fusion happened.
    """

    def _try_pair(dom, other):
        """Fuse ``other`` into ``dom`` if every check passes."""
        if not (dom.check_acyclic(other) and other.check_acyclic(dom)):
            return False
        if not selector(dom, other):
            return False
        if not self.limit_area_size(dom, [other], 64):
            return False
        if not dom.fuse_confirm(other):
            return False
        dom.fuse(other)
        self.set_area_map(other.ops, dom)
        self.areas.remove(other)
        return True

    def _fuse_first_pair(siblings):
        """Fuse the first fusible ordered pair among ``siblings``."""
        for pos, dom in enumerate(siblings[:-1]):
            for other in siblings[pos + 1:]:
                if _try_pair(dom, other):
                    return True
        return False

    def _users_of(tensor):
        """Collect the live areas holding a consumer of ``tensor``, without duplicates."""
        found = []
        for op in tensor.to_ops:
            holder = self.area_map.get(op)
            if holder in self.areas and holder not in found:
                found.append(holder)
        return found

    changed = False

    # Phase 1: siblings are the output areas of a common producer area.
    progressed = True
    while progressed:
        progressed = False
        for dom in self.areas:
            if len(dom.out_relations) > 1 and _fuse_first_pair(list(dom.out_relations.keys())):
                progressed = changed = True
                break

    # Phase 2: siblings are the areas reading the same graph input.
    inputs, _ = self.graph.deduce_parameters()
    progressed = True
    while progressed:
        progressed = False
        for tensor in inputs:
            siblings = _users_of(tensor)
            if len(siblings) > 1 and _fuse_first_pair(siblings):
                progressed = changed = True
                break
    return changed
750
+
751
def fuse_recom(self, selector):
    """Try to fuse the recompute area into its user.

    The selector is evaluated for the recompute area and for its user; the
    fusion happens when the first proposed area is one of those two and
    the user confirms it.

    Returns:
        bool. True if the recompute area was fused into its user.
    """
    recom = self.recom_area
    user = recom.recom_user
    for dominant in (recom, user):
        proposal = selector(dominant)
        if not (proposal and proposal[0]):
            continue
        candidates = self.limit_area_size(dominant, proposal[0], self.flags['composite_op_limit_size'])
        if not candidates:
            continue
        if candidates[0] not in [self.recom_area, user]:
            continue
        if not user.fuse_confirm(self.recom_area):
            continue
        user.fuse(self.recom_area)
        self.set_area_map(self.recom_area.ops, user)
        return True
    return False
766
+
767
def index_op(self):
    """Number the graph's ops by position, for topological sorting.

    When recompute is enabled, every copied op in the recompute map gets
    the index of the op it was copied from.

    Returns:
        dict. Maps an op to its index.
    """
    ids = {op: position for position, op in enumerate(self.graph.ops)}
    if self.enable_recompute:
        for copied, original in self.recom_area.recom_map.items():
            ids[copied] = ids.get(original)
    return ids
776
+
777
def to_subgraphs(self):
    """Turn every area into a sub-graph plus its mode name.

    Returns:
        tuple. ``(subgraphs, graphmodes)``: the ``Graph`` objects and, for
        each of them, ``"basic"`` or ``"composite"``.
    """
    order = self.index_op()
    subgraphs, graphmodes = [], []
    for position, area in enumerate(self.areas):
        # Sort the ops of the area by their index from index_op().
        area.ops.sort(key=order.get)
        graph_name = '{}_{}'.format(self.graph.name, position)
        subgraphs.append(Graph(graph_name, area.ops, area.stitch_info, area.recompute_ops))
        if area.mode == self.Area.MODE_BASIC:
            graphmodes.append("basic")
        else:
            graphmodes.append("composite")
    return subgraphs, graphmodes
787
+
788
def pattern_fuse(self, fuse_func=None):
    """fuse Areas by pattern repeatedly

    Abstract hook: the base class has no fusion rules, subclasses must
    override it.

    Args:
        fuse_func: Unused here.

    Raises:
        NotImplementedError: Always. It is a subclass of `Exception`, so
            callers catching `Exception` keep working.
    """
    del fuse_func
    raise NotImplementedError("pattern_fuse() is not implemented in {}".format(self.__class__.__name__))
792
+
793
def split(self):
    """Run the whole splitting pipeline.

    Returns:
        tuple. ``(subgraphs, graphmodes)`` as produced by ``to_subgraphs``.
    """
    self.pattern_fuse()
    if self.enable_recompute:
        self.recompute_fuse()
    # Reshape ops are split out into their own areas here. After this call
    # the input/output relations of the areas are no longer maintained.
    self.split_output_reshapes()
    graphs, modes = self.to_subgraphs()
    return graphs, modes
803
+
804
def split_output_reshapes(self):
    """Move the trailing Reshape ops of each area into new areas of their own.

    A Reshape stays in its area when one of its consumers is kept in that
    area; every remaining Reshape gets a new single-op area.
    """

    def _keep_inner_reshapes(reshapes, keepers):
        """Move reshapes consumed by ``keepers`` into ``keepers``, until none is left."""
        moved = True
        while moved:
            moved = False
            for op in reshapes:
                if any((user in keepers for user in op.output.to_ops)):
                    reshapes.remove(op)
                    keepers.append(op)
                    moved = True
                    # The list changed: restart the scan from the beginning.
                    break

    created = []
    for area in self.areas:
        reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
        other_ops = [op for op in area.ops if op not in reshape_ops]
        # Nothing to separate when the area is all-reshape or reshape-free.
        if not other_ops or not reshape_ops:
            continue
        _keep_inner_reshapes(reshape_ops, other_ops)
        if not reshape_ops:
            continue
        for op in reshape_ops:
            single = self.Area(op, False, 0, self.reach_tab)
            self.set_default_mode(single)
            created.append(single)
        area.ops = other_ops
        if len(other_ops) == 1:
            self.set_default_mode(area)
    if created:
        self.areas += created
838
+
839
def recompute_fuse(self):
    """Find cheap producer regions on area borders and try to recompute them in the user area."""

    def _trace_producers(area, border):
        """Walk from ``border`` up through ``area``.

        Returns:
            tuple. ``(ops, inputs)``: the ops of the region and the tensors
            entering it from outside ``area``; ``([], [])`` when an op has
            more than one input, has an iter type above BROADCAST, or the
            region grows too large.
        """
        weight_cap = 10
        pending = [border]
        region, outer_inputs = [], []
        while pending:
            cur = pending.pop()
            unsuitable = len(cur.inputs) > 1 or PrimLib.iter_type(cur) > PrimLib.BROADCAST
            if unsuitable or len(region) > weight_cap:
                return [], []
            region.append(cur)
            for tensor in cur.inputs:
                if tensor.op in area.ops:
                    pending.append(tensor.op)
                else:
                    outer_inputs.append(tensor)
        return region, outer_inputs

    def _collect_borders(area):
        """Group the border ops of ``area`` by the outside area that uses them."""
        prods, users = {}, {}
        for op in area.ops:
            out = op.output
            # Only ops with several consumers, or producing a graph output, are borders.
            if len(out.to_ops) <= 1 and out.para_type != Tensor.PARA_OUTPUT:
                continue
            for consumer in out.to_ops:
                if consumer in area.ops:
                    continue
                user = self.area_map.get(consumer)
                if user.pattern > PrimLib.RESHAPE:
                    users.setdefault(user, []).append(op)
                    if op not in prods:
                        prods[op] = _trace_producers(area, op)
        return prods, users

    def _cheap_regions(prods, borders):
        """Return ``[pred_area, ops]`` pairs for regions whose inputs are no larger than the output."""
        if len(borders) > 1:
            return []
        regions = []
        for op in borders:
            region, outer_inputs = prods[op]
            if not region:
                continue
            if sum([t.get_size() for t in outer_inputs]) <= op.output.get_size():
                if outer_inputs and outer_inputs[0].op:
                    pred = self.area_map.get(outer_inputs[0].op)
                else:
                    pred = None
                # The region was collected from the border upwards; reverse it.
                regions.append([pred, region[::-1]])
        return regions

    def _recompute_area(area):
        """Try every cheap region of ``area``; stop at the first successful fusion."""
        prods, users = _collect_borders(area)
        for user, borders in users.items():
            for pred, region in _cheap_regions(prods, borders):
                self.recom_area.reset(area, region, user, pred)
                self.pattern_fuse(self.fuse_recom)
                self.recom_area.clear()
                if self.recom_area.fuse_success:
                    return True
        return False

    progressed = True
    while progressed:
        progressed = False
        # Iterate over a snapshot; areas fused away meanwhile are skipped.
        for area in list(self.areas):
            if area in self.areas and area.out_relations:
                progressed = _recompute_area(area) or progressed
        if progressed:
            self.pattern_fuse()
914
+
915
+
916
class GraphSplitGpu(GraphSplitByPattern):
    """Graph splitter with the fusion rules used for GPU."""
    # Areas with more ops than this are not fused backward by the broadcast rules.
    BROADCAST_FUSE_DEPTH = 20
    # Areas with more ops than this are not fused by the transpose/injective rules.
    TRANSPOSE_FUSE_DEPTH = 6

    def __init__(self, graph, flags):
        super().__init__(graph, flags)
        # A negative value means "not configured"; fall back to 20.
        self.reduce_fuse_depth = 20 if self.reduce_fuse_depth < 0 else self.reduce_fuse_depth

    def get_default_mode(self, op):
        """Get default mode in GPU"""
        if op.prim == "MatMul":
            return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" and op.attrs['Akg'] else \
                self.Area.MODE_BASIC
        if op.prim == "Assign":
            return self.Area.MODE_BASIC
        pattern = PrimLib.iter_type(op)
        return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE

    def pattern_fuse(self, fuse_func=None):
        """fuse Areas by pattern

        Selectors passed to `fuse` take a dominant area and return
        `(areas, is_forward)`; selectors passed to `hfuse` take two areas
        and return a boolean.

        Args:
            fuse_func (callable): When None, the full rule list is applied
                with `self.fuse`/`self.hfuse`. Otherwise a reduced rule
                list is tried through `fuse_func`, stopping at the first
                rule group that reports a change.
        """

        def _broadcast_pat_exclude(dom, a, r):
            if a.pattern == PrimLib.REDUCE:
                return dom.pattern > PrimLib.ELEMWISE or r > PrimLib.ELEMWISE
            return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST

        def _broadcast_bwd_depth(dom):
            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1:
                return [], False
            if dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
                return [], False
            a, r = list(dom.out_relations.items())[0]
            if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
                return [], False
            return [a], False

        def _broadcast_bwd_width(dom):
            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
                    dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
                return [], False
            fused = []
            for a, r in dom.out_relations.items():
                if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a):
                    return [], False
                # All users must produce outputs of the same size.
                if fused and tensor_size(fused[0].dom_op().output) != tensor_size(a.dom_op().output):
                    return [], False
                fused.append(a)
            return fused, False

        def _reduce_pat_exclude(_, a, r):
            if len(a.ops) > self.reduce_fuse_depth:
                return True
            return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST

        def _reduce_depth(dom):
            if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
                return [], False
            a, r = list(dom.in_relations.items())[0]
            if dom.ops[0].inputs[0].dtype == "float16" and a.is_output:
                if len(a.ops) >= 10 and _is_atomic_add_available(dom):
                    # to evade the precision problem.
                    return [], False
            if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
                # Same (areas, is_forward) shape as every other return path.
                return [], False
            return [a], True

        def _reduce_width(dom):
            if dom.pattern != PrimLib.REDUCE:
                return [], False
            fused = []
            for a, r in dom.in_relations.items():
                if dom.ops[0].inputs[0].dtype == "float16" and a.is_output:
                    if len(a.ops) >= 10 and _is_atomic_add_available(dom):
                        # to evade the precision problem.
                        continue
                if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
                    fused.append(a)
            return fused, True

        def _is_atomic_add_available(dom):
            if any(("Reduce" in x.prim for x in dom.ops[1:])):
                return False
            op = dom.ops[0]
            if "reduce_axis" in op.attrs:
                reduce_axis = op.attrs["reduce_axis"]
            elif "axis" in op.attrs:
                reduce_axis = [op.attrs["axis"]]
            else:
                raise Exception("For '{}', can not find the attr 'reduce_axis' or 'axis'".format(op.prim))
            # When the last axis is reduced, the product of the reduced dims must reach 1024.
            if op.inputs and len(op.inputs[0].shape) - 1 in reduce_axis:
                reduce_size = prod_reduce(lambda x, y: x * y, (op.inputs[0].shape[i] for i in reduce_axis))
                return reduce_size >= 1024
            return True

        def _may_multi_filter(dom_ops):
            # Count the ops reached from dom_ops[0] through inputs inside dom_ops;
            # True when that count stays below len(dom_ops).
            count = 1
            stack = [dom_ops[0]]
            while stack:
                op = stack.pop()
                for t in op.inputs:
                    if t.op and t.op in dom_ops:
                        count = count + 1
                        stack.append(t.op)
            return count < len(dom_ops)

        def _reduce_output(dom):
            if dom.pattern != PrimLib.REDUCE:
                return [], False
            if _may_multi_filter(dom.ops):
                return [], False
            if _is_atomic_add_available(dom):
                return [], False
            is_all_reduce = tensor_size(dom.ops[0].output) == 1
            # excluded large size all reduce
            if is_all_reduce and dom.ops[0].inputs and tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
                return [], False

            fused = []
            for a, r in dom.out_relations.items():
                if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
                    fused.append(a)
            return fused, False

        def _reduce_stitch(dom):
            if dom.pattern != PrimLib.REDUCE:
                return [], False
            if tensor_size(dom.ops[0].output) == 1:
                return [], False
            if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
                return [], False

            fused = []
            for a, r in dom.out_relations.items():
                if not may_stitch(dom, a, r, 1024 * 8, 1024 * 1024):
                    continue
                if a.pattern == PrimLib.REDUCE:
                    if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
                        dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
                        fused.append(a)
                elif a.pattern == PrimLib.BROADCAST:
                    dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
                    fused.append(a)
            return fused, False

        def _transpose(dom):
            if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose":
                return [], False
            fused = []
            for a, _ in dom.in_relations.items():
                if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and len(a.ops) <= self.TRANSPOSE_FUSE_DEPTH:
                    fused.append(a)
            return fused, True

        def _strided_slice(dom):
            if dom.dom_op().prim != "StridedSlice":
                return [], False
            fused = []
            for a, _ in dom.in_relations.items():
                if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
                        len(a.out_relations) == 1 and not a.is_output:
                    fused.append(a)
            return fused, True

        def _gather_output(dom, reduce_fusion=False):
            gather_prims = ("Gather", "GatherNd", "CSRGather")
            if not dom.dom_op().prim in gather_prims:
                return [], False

            def _reduce_exclude(op, axis_list):
                """ Whether this operator should be excluded.
                Excluding condition:
                1. There are at least one same axis between reduce axes and axis_list.

                Args:
                    op (Operator): Target reduce operator.
                    axis_list (list): List to check whether it is intersected by reduce axis.
                Returns:
                    Boolean. Whether this operator should be excluded.
                """
                axis = op.attrs["reduce_axis"]
                if isinstance(axis, int):
                    axis = [axis]
                in_shape_len = len(op.inputs[0].shape)
                # NOTE: when "reduce_axis" is a list, negative entries are
                # normalized in place, i.e. written back into op.attrs.
                for i, dim in enumerate(axis):
                    axis[i] = in_shape_len + dim if dim < 0 else dim
                # Axes of extent 1 are ignored.
                fix_axis = []
                for ax in axis:
                    if op.inputs[0].shape[ax] == 1:
                        continue
                    fix_axis.append(ax)
                return bool(set(fix_axis) & set(axis_list))

            def _bfs_visit(start_op, start_prims, total_ops, end_ops, gather_axis):
                consisten_shape = start_op.output.shape
                visited = []
                op_queue = [start_op]

                def _early_stop(cur_op):
                    if cur_op in end_ops:
                        # If reduce the gather axis, stop early for not fusion.
                        if cur_op.prim == "ReduceSum" and _reduce_exclude(cur_op, gather_axis):
                            return True
                    else:
                        if (cur_op.prim in start_prims and cur_op != start_op) or \
                                consisten_shape != cur_op.output.shape:
                            return True
                    return False

                while op_queue:
                    tmp_queue = []
                    for op in op_queue:
                        if op in visited or op not in total_ops:
                            continue
                        if _early_stop(op):
                            return False
                        if op in end_ops:
                            continue
                        for to_op in op.output.to_ops:
                            tmp_queue.append(to_op)
                        visited.append(op)
                    op_queue = tmp_queue
                return True

            def _shape_consistent(start_prims, end_prims, source, target):
                """
                Check whether it is always shape consistent from source nodes to target nodes.
                Excluding condition:
                    When fusing ReduceSum, first check if TensorScatterAdd and/or UnsortedSegmentSum
                    has already been fused, if so, stop ReduceSum fusion.
                """
                total_ops = source.ops + target.ops
                op_prims_set = {op.prim for op in total_ops}
                if reduce_fusion and (len({"TensorScatterAdd", "UnsortedSegmentSum"} & op_prims_set) >= 1):
                    return False
                start_ops = []
                for op in source.ops:
                    if op.prim in start_prims:
                        start_ops.append(op)
                end_ops = []
                for op in total_ops:
                    if op.prim in end_prims and not any((to_op in total_ops for to_op in op.output.to_ops)):
                        end_ops.append(op)

                for start_op in start_ops:
                    gather_axis = start_op.attrs.get("axis", None)
                    if gather_axis is None:
                        # For GatherNd
                        gather_axis = list(range(len(start_op.inputs[1].shape)))
                    elif isinstance(gather_axis, int):
                        gather_axis = [gather_axis]

                    is_consistent = _bfs_visit(start_op, start_prims, total_ops, end_ops, gather_axis)
                    if not is_consistent:
                        return False
                return True

            if reduce_fusion:
                appected_areas = {"ReduceSum", "CSRReduceSum"}
            else:
                appected_areas = {"TensorScatterAdd", "UnsortedSegmentSum"}

            for a, _ in dom.out_relations.items():
                if _shape_consistent(gather_prims, appected_areas, dom, a) and dom.check_acyclic(a):
                    return [a], False
            return [], False

        def _broadcast_tot(dom):
            """Fuse rule for TensorScatterAdd and UnsortedSegmentSum."""

            def _same_input(op1, op2):
                return bool(set(op1.inputs) & set(op2.inputs))

            if len(dom.ops) != 1:
                return [], False

            # Inputs considered for fusion: inputs[1:] for `TensorScatterAdd`,
            # inputs[0:2] for `UnsortedSegmentSum`.
            fuse_arg = {"TensorScatterAdd": slice(1, None), "UnsortedSegmentSum": slice(0, 2)}
            arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
            if arg_idx == -1:
                return [], False
            fuse_tensor = dom.dom_op().inputs[arg_idx]

            for a, _ in dom.in_relations.items():
                if not a.check_acyclic(dom):
                    continue
                # Rule 1: Same type with at least one same input.
                if a.dom_op().prim == dom.dom_op().prim and _same_input(dom.dom_op(), a.dom_op()):
                    return [a], True
                # Rule 2: Fuse op(reshape/elementwise/broadcast) in specified position inputs.
                if a.pattern <= PrimLib.BROADCAST and any((op.output in fuse_tensor for op in a.ops)):
                    return [a], True
            return [], False

        def _broadcast_onehot(dom, fwd=True):
            """Fuse rule for OneHot."""
            if dom.dom_op().prim != "OneHot":
                return [], False

            fused = []
            neighbours = dom.in_relations.items() if fwd else dom.out_relations.items()
            for a, _ in neighbours:
                if a.pattern <= PrimLib.BROADCAST:
                    if fwd:
                        if a.check_acyclic(dom) and len(a.out_relations) == 1 and not a.is_output:
                            fused.append(a)
                    else:
                        if dom.check_acyclic(a):
                            fused.append(a)

            return fused, fwd

        def _elemwise_elemany(dom):
            """Fuse rule for elemany."""
            if dom.dom_op().prim != "ElemAny":
                return [], False

            fused = []
            for a, r in dom.in_relations.items():
                if a.pattern < PrimLib.BROADCAST and r <= PrimLib.ELEMWISE and a.check_acyclic(dom):
                    fused.append(a)

            return fused, True

        def _injective_output(dom):
            """Fuse rule for injective """
            injective_ops = {"Transpose", "StridedSlice"}
            if dom.dom_op().prim not in injective_ops:
                return [], False
            to_ops = dom.dom_op().output.to_ops
            if dom.is_output or len(to_ops) != 1 or len(dom.out_relations) != 1:
                return [], False
            to_area = list(dom.out_relations.keys())[0]
            if (to_area.pattern >= PrimLib.REDUCE and to_area.dom_op().prim not in injective_ops) or \
                    to_ops[0] not in to_area.ops:
                return [], False
            if len(to_area.ops) > self.TRANSPOSE_FUSE_DEPTH:
                return [], False
            return [to_area], False

        def _h_broadcast(dom, a):
            # hfuse() evaluates the selector result for truth, so a rejection
            # must be falsy; a tuple such as ([], False) would count as "accept".
            if dom.pattern > PrimLib.BROADCAST:
                return False
            return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape

        def _h_reduce(dom, a):
            if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops:
                return False
            dom_op = dom.ops[0]
            if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom):
                return False
            op = a.ops[0]
            return a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
                PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
                dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis")

        def _h_opaque(dom, a):
            if dom.ops[0].prim not in {"StridedSlice"}:
                return False
            return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
                dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape

        def _link_csr(dom):
            def _same_input(op1, op2):
                return bool(set(op1.inputs.copy()) & set(op2.inputs.copy()))

            fuse_arg = {"CSRReduceSum": slice(1, 3), "CSRGather": slice(2, 3)}
            arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
            if arg_idx == -1:
                return [], False
            fuse_tensor = dom.dom_op().inputs[arg_idx]
            for a, _ in dom.in_relations.items():
                if (a.dom_op().prim == "CSRGather" and a.dom_op().prim == dom.dom_op().prim and
                        _same_input(dom.dom_op(), a.dom_op())):
                    return [a], True
                # NOTE(review): the other input-side rules call a.check_acyclic(dom);
                # here the call is dom.check_acyclic(a) -- confirm this is intended.
                if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
                        any(op.output in fuse_tensor for op in a.ops):
                    return [a], True
            return [], False

        def _fuse_loop():
            # Full rule list; every fuse()/hfuse() call repeats its rule until nothing changes.
            self.fuse(CommonPattern.reshape)
            self.fuse(CommonPattern.assign)
            self.fuse(CommonPattern.elemwise_depth)
            self.fuse(CommonPattern.elemwise_width)
            self.fuse(_broadcast_tot)
            self.fuse(_link_csr)
            self.fuse(CommonPattern.broadcast_depth)
            self.fuse(CommonPattern.broadcast_width)
            self.fuse(_reduce_depth)
            self.fuse(_reduce_width)
            self.fuse(_broadcast_bwd_depth)
            self.fuse(_broadcast_bwd_width)
            self.fuse(_strided_slice)
            self.fuse(partial(_broadcast_onehot, fwd=True))
            self.fuse(partial(_broadcast_onehot, fwd=False))
            self.fuse(partial(_gather_output, reduce_fusion=False))
            self.fuse(partial(_gather_output, reduce_fusion=True))
            self.fuse(_reduce_output)
            if self.enable_stitch_fusion:
                self.fuse(_reduce_stitch, True)
            self.fuse(_transpose)
            self.fuse(_injective_output)
            self.fuse(CommonPattern.isolate_reshape)
            if self.enable_horizontal_fusion:
                self.hfuse(_h_broadcast)
                self.hfuse(_h_reduce)
                self.hfuse(_h_opaque)
            self.fuse(_elemwise_elemany)

        def _fuse_once(fuse_func):
            # Reduced rule list; the `or` chain stops at the first rule that reports a change.
            if fuse_func(CommonPattern.reshape) or \
                    fuse_func(CommonPattern.elemwise_depth) or fuse_func(CommonPattern.elemwise_width) or \
                    fuse_func(CommonPattern.broadcast_depth) or fuse_func(CommonPattern.broadcast_width) or \
                    fuse_func(_reduce_depth) or fuse_func(_reduce_width) or \
                    fuse_func(_broadcast_bwd_depth) or fuse_func(_broadcast_bwd_width):
                return
            if fuse_func(_reduce_output):
                return
            fuse_func(_transpose)

        if fuse_func is None:
            _fuse_loop()
        else:
            _fuse_once(fuse_func)
1341
+
1342
+
1343
def split(graph, target, flags):
    """Split ``graph`` for the given target.

    Returns:
        The result of ``GraphSplitGpu(graph, flags).split()`` when ``target``
        is ``"cuda"``, otherwise None.
    """
    if target == "cuda":
        return GraphSplitGpu(graph, flags).split()
    return None