mindspore 2.4.0__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1387):
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1819 @@
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """SymbolTree class define of Rewrite according to forward function of a network."""
16
+ import stat
17
+ from typing import Optional, Union, Tuple, Any, Dict, List
18
+ import types
19
+ import os
20
+ import sys
21
+ import ast
22
+ import importlib.util
23
+ import time
24
+ import inspect
25
+ from textwrap import dedent
26
+ from collections import OrderedDict
27
+
28
+ from mindspore.nn import Cell
29
+ from mindspore import log as logger
30
+ from .symbol_tree_dumper import SymbolTreeDumper
31
+ from ..node import Node, TreeNode, ControlFlow, CallFunction, NodeManager
32
+ from ..api.node_type import NodeType
33
+ from ..api.scoped_value import ScopedValue, ValueType
34
+ from ..ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder, \
35
+ AstImportFinder
36
+ from ..common.namer import TargetNamer, NodeNamer, ClassNamer
37
+ from ..common.observer import Observer
38
+ from ..common.observable import Observable
39
+ from ..common.event import Event
40
+
41
+ if sys.version_info >= (3, 9):
42
+ import ast as astunparse # pylint: disable=reimported, ungrouped-imports
43
+ else:
44
+ import astunparse
45
+
46
class Position:
    """
    Position indicates a source code position in one network.

    Rewrite recommends using the class method `create()` rather than calling the constructor directly.

    Args:
        symbol_tree (SymbolTree): The SymbolTree the position belongs to.
        node (Node): The Node the position is located next to.
        before_node (bool): True when the position is before `node`, False when it is after `node`.
    """

    def __init__(self, symbol_tree, node, before_node: bool):
        self.symbol_tree = symbol_tree
        self.node = node
        self.before_node = before_node

    @classmethod
    def create(cls, symbol_tree, node, before_node):
        """
        Create a position, or return None when `symbol_tree` or `node` is None.

        Args:
            symbol_tree: The SymbolTree the position belongs to.
            node: The Node the position is located next to.
            before_node (bool): True when the position is before `node`, False when it is after `node`.

        Returns:
            An instance of `cls`, or None if `symbol_tree` or `node` is None.
        """
        if symbol_tree is None or node is None:
            return None
        # Instantiate through `cls` so that a subclass calling `create()` gets an instance of its own type.
        return cls(symbol_tree, node, before_node)
79
+
80
+
81
class FieldFinder(AstFinder):
    """
    Finder telling whether an attribute accessed through `self` appears inside a scope.

    Args:
        scope (ast.AST): The ast node used as search scope.
    """

    def __init__(self, scope: ast.AST):
        super().__init__(scope)
        self._field_name = ""
        self._result = False

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Record a hit when `node` is `self.<field_name>`, then keep visiting its children."""
        owner = node.value
        accessed_on_self = isinstance(owner, ast.Name) and owner.id == "self"
        if accessed_on_self and node.attr == self._field_name:
            self._result = True
        return super().generic_visit(node)

    def check(self, field) -> bool:
        """
        Tell whether `self.<field>` is referenced in the scope.

        Args:
            field (str): Name of the field to look for.

        Returns:
            True if an attribute named `field` is accessed on `self` inside the scope, False otherwise.
        """
        # Reset the state so that one finder can be reused for several fields.
        self._field_name = field
        self._result = False
        self.visit(self._scope)
        return self._result
119
+
120
+
121
+ class SymbolTree(Observer, Observable, NodeManager):
122
+ """
123
+ A symbol-tree usually corresponding to forward method of a network.
124
+
125
+ Rewrite recommend using SymbolTreeBuilder to instantiate an instance of SymbolTree rather than invoking constructor
126
+ of SymbolTree directly.
127
+
128
+ Args:
129
+ origin_network (Cell): A handler to original network instance.
130
+ module_ast (ast.Module): An instance of ast.AST represents ast node of original network.
131
+ """
132
+ # whether parse CallFunction node inserted by user.
133
+ _unparse_inserted_function = True
134
+
135
def __init__(self, origin_network: Cell, module_ast: ast.Module):
    """
    Initialize the bookkeeping of a SymbolTree.

    Args:
        origin_network (Cell): The original network instance.
        module_ast (ast.Module): The module ast of the original network.
    """
    Observer.__init__(self)
    Observable.__init__(self)
    self._node_namer = NodeNamer()
    # 'obj' is the name the generated init code uses for the original network, keep it out of node names.
    self._node_namer.add_name('obj')
    NodeManager.__init__(self)
    NodeManager.set_manager_node_namer(self, self._node_namer)
    NodeManager.reg_observer(self, observer=self)
    # init unique-namers
    self._target_namer = TargetNamer()
    # class name of the original network and the name assigned to the rewritten class
    self._ori_cls_name = type(origin_network).__name__
    self._opt_cls_name = ClassNamer.instance().get_name(self._ori_cls_name)
    NodeManager.set_manager_name(self, self._opt_cls_name)
    self._origin_network = origin_network
    self._module_ast: ast.Module = module_ast
    self._import_asts: List[ast.AST] = []
    self._class_ast: Optional[ast.ClassDef] = None
    self._root_ast: Optional[ast.FunctionDef] = None
    self._init_func_ast: Optional[ast.FunctionDef] = None
    self._deleted_field = {}
    self._deleted_node = []
    # {ast_function: [import_asts]}
    self._external_ast: Dict[ast.FunctionDef, list] = OrderedDict()
    # {ast_class: [import_asts]}
    self._father_class_ast: Dict[ast.ClassDef, list] = OrderedDict()
    self._modified = False
    self._saved_file_name = "./network_define.py"
    # used to insert "sys.path.append(xxx)"
    self._net_file_paths = []
    self._tmp_import_strs = []
    # {network_type: [symbol_trees]}
    self._tmp_unmodified_strees: Dict[type, List['SymbolTree']] = {}
    self._tmp_replacers = []
    # user custom codes
    self._custom_codes: List[ast.AST] = []
    # local primitive instances initialized during forward method, e.g. abs_inst = P.Abs()
    self._local_prim_inits: List[Node] = []
172
+
173
@staticmethod
def _remove_unused_import(module_ast):
    """
    Remove imported names which are not used by the class or function following them.

    Imports are collected while walking the module body. Each time a class or function definition is met,
    every collected import is pruned of the names this definition does not use. A divider (an `ast.Expr`
    wrapping an `ast.Name` whose id is '#') starts a new group and discards the collected imports.

    Args:
        module_ast (ast.Module): The module ast to be cleaned in place.
    """
    import_nodes: List[Union[ast.Import, ast.ImportFrom]] = []

    def is_divider(ast_node):
        """judge if ast node is divider of new class or function by checking ast.Expr of '#'."""
        return isinstance(ast_node, ast.Expr) and isinstance(ast_node.value, ast.Name) and ast_node.value.id == '#'

    # Iterate over a copy because import statements may be removed from the body inside the loop.
    for ast_node in module_ast.body[:]:
        if isinstance(ast_node, (ast.Import, ast.ImportFrom)):
            import_nodes.append(ast_node)
        if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef)):
            str_checker = StrChecker(ast_node)
            for import_node in import_nodes[:]:
                for alias in import_node.names[:]:
                    name = alias.asname if alias.asname else alias.name
                    # wildcard imports can not be checked by name, always keep them
                    if name == '*':
                        continue
                    if not str_checker.check(name):
                        import_node.names.remove(alias)
                if not import_node.names:
                    module_ast.body.remove(import_node)
                    # Forget the emptied import, otherwise the next definition of the same group would try
                    # to remove it from the body a second time and raise ValueError.
                    import_nodes.remove(import_node)
        if is_divider(ast_node):
            import_nodes.clear()
198
+
199
@staticmethod
def _remove_duplicated_import(module_ast):
    """
    Drop repeated imports and repeated class/function definitions from the body of `module_ast`.

    Args:
        module_ast (ast.Module): The module ast to be cleaned in place.
    """
    seen_imports = set()
    future_modules = set()
    defined_names = set()

    def keep_first(key, seen, node):
        """Return `node` the first time `key` is met and None (which deletes the node) afterwards."""
        if key in seen:
            return None
        seen.add(key)
        return node

    class TransImportNode(ast.NodeTransformer):
        """Transformer keeping only the first occurrence of each import and definition."""

        def visit_ClassDef(self, node: ast.ClassDef) -> Any:
            return keep_first(node.name, defined_names, node)

        def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
            return keep_first(node.name, defined_names, node)

        def visit_Try(self, node: ast.Try) -> Any:
            # Only try-statements whose first statement is an import are kept (once).
            if not isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
                return None
            return keep_first(astunparse.unparse(node), seen_imports, node)

        def visit_Import(self, node: ast.Import) -> Any:
            return keep_first(astunparse.unparse(node), seen_imports, node)

        def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
            """
            Names already defined in the current module are stripped from later 'from ... import ...'
            statements, e.g. once class 'A' is defined:
                from xxx import A, B
            =>
                from xxx import B
            """
            import_text = astunparse.unparse(node)
            if import_text in seen_imports:
                return None
            seen_imports.add(import_text)
            # "__future__" imports are removed
            if node.module == '__future__':
                future_modules.add(node.module)
                return None
            # Strip names which have been defined in the code file, it occurs when class A is a father
            # class and other sub-classes import A.
            node.names[:] = [alias for alias in node.names if alias.name not in defined_names]
            # The statement itself is removed when no name is left.
            return node if node.names else None

    TransImportNode().generic_visit(module_ast)
269
+
270
@staticmethod
def _remove_arg_annotations(module_ast):
    """Clear the annotation of every ast.arg found in `module_ast` to avoid 'xxx is not defined'."""
    for arg_node in AstFinder(module_ast).find_all(ast.arg):
        arg_node.annotation = None
276
+
277
@staticmethod
def _check_import(import_path: str, import_module: str):
    """
    Try to import `import_module` after making sure `import_path` is in sys.path.

    Args:
        import_path (str): Path appended to sys.path when it is not already there.
        import_module (str): Name of the module to import.

    Returns:
        True if the import succeeds, False if it raises any exception.
    """
    if import_path not in sys.path:
        sys.path.append(import_path)
    try:
        importlib.import_module(name=import_module)
    except Exception as err:  # pylint: disable=W0703
        logger.info(f"Test import {import_module} from {import_path} failed: {err}.")
        return False
    return True
293
+
294
@staticmethod
def _process_relative_import(import_node: Union[ast.Import, ast.ImportFrom], file_path: str):
    """
    Turn a relative 'from ... import ...' into an absolute one when its module can be resolved.

    Args:
        import_node (Union[ast.Import, ast.ImportFrom]): The import statement to process.
        file_path (str): Path of the file which contains the import statement.

    Returns:
        A new ast.ImportFrom with level 0 when a module name is resolved, otherwise `import_node` itself.
    """
    file_path = os.path.normpath(os.path.normcase(file_path))
    if not isinstance(import_node, ast.ImportFrom):
        return import_node
    # pad the ImportFrom with parent path
    # e.g. from ..C import xxx -> from A.B.C import xxx
    resolved_module = SymbolTree._get_valid_import_info(import_node, file_path)
    if not resolved_module:
        return import_node
    return ast.ImportFrom(module=resolved_module, names=import_node.names, level=0)
306
+
307
@staticmethod
def _get_valid_import_info(import_node: ast.ImportFrom, file_path: str):
    """
    Find an importable absolute module name for a (possibly relative) 'from ... import ...' statement.

    Args:
        import_node (ast.ImportFrom): The import statement.
        file_path (str): Path of the file which contains the import statement.

    Returns:
        `import_node.module` when the import is not relative, the first absolute module name whose test
        import succeeds otherwise, or None when no candidate can be imported.
    """
    # from A import xxx: it does not need to pad, directly return the module name
    if import_node.level == 0:
        return import_node.module
    # from .(A) import xxx: directory of the file
    # from ..(A) import xxx: one directory upper for each extra dot
    package_dir = os.path.dirname(os.path.realpath(file_path))
    for _ in range(import_node.level - 1):
        package_dir = os.path.dirname(package_dir)
    # module_suffix is the module name, e.g. '.A' for 'from ..A import xxx'
    module_suffix = '.' + import_node.module if import_node.module else ''
    candidate_root = package_dir
    # Climb towards the filesystem root, trying each ancestor which is already in sys.path.
    for _ in range(package_dir.count(os.path.sep) - 1):
        candidate_root = os.path.dirname(candidate_root)
        if candidate_root not in sys.path:
            logger.debug(f"{candidate_root} not in sys.path, try upper level.")
            continue
        dotted_package = package_dir[len(candidate_root) + 1:].replace(os.path.sep, '.')
        import_module = dotted_package + module_suffix
        if SymbolTree._check_import(candidate_root, import_module):
            # test import succeeded
            return import_module
        # test import failed, try upper level
        logger.info("Try upper level.")
    # test import failed for every level
    logger.info(f"Test import code: {astunparse.unparse(import_node).strip()} failed, ignore this import code.")
    return None
344
+
345
@staticmethod
def insert_to_ast_while_insert_input(new_node: Node, node_manager: NodeManager):
    """
    Update the ast when a NodeType.Input node is inserted.

    Args:
        new_node (Node): The input node being inserted.
        node_manager (NodeManager): The SymbolTree or CallFunction node receiving the new input.

    Raises:
        ValueError: If `node_manager` is neither a SymbolTree nor a CallFunction.
    """
    if not isinstance(node_manager, (SymbolTree, CallFunction)):
        raise ValueError(f"Only support insert Input node into a SymbolTree or a node with type of "
                         f"CallFunction, but get {type(node_manager)}")
    # register the new input on the manager
    node_manager.get_input_nodes().append(new_node)
    # append an argument named after the first target of the node to the function definition
    function_ast: ast.FunctionDef = node_manager.get_manager_ast()
    arg_name: str = new_node.get_targets()[0].value
    AstModifier.append_arg_to_function(function_ast, ast.arg(arg=arg_name, annotation=None, type_comment=None))
357
+
358
@staticmethod
def insert_to_ast_while_insert_cell_primitive(new_node: Node, base_node: Node, before_node: bool,
                                              node_manager: NodeManager, stree):
    """
    Update the ast when a NodeType.CallCell or NodeType.CallPrimitive node is inserted.

    Args:
        new_node (Node): The node being inserted.
        base_node (Node): The node used as insert location, may be None.
        before_node (bool): Whether `new_node` is inserted before `base_node`.
        node_manager (NodeManager): The node manager whose ast receives the new statement.
        stree (SymbolTree): The symbol tree owning the init function and the original network.

    Raises:
        ValueError: If the ast of `new_node` is not an ast.Assign.
        RuntimeError: If `node_manager` has no ast.
    """
    # get the assign statement of the node, create it when the node has none yet
    assign_ast = new_node.get_ast()
    if assign_ast is None:
        unique_name = stree.unique_func_name(new_node.get_name())
        new_node.set_func_name(ScopedValue.create_naming_value(unique_name, "self"))
        assign_ast = new_node.update_ast_node()
    if not isinstance(assign_ast, ast.Assign):
        raise ValueError(f"Only support insert ast.Assign or Input now, but get {type(assign_ast)}")
    # Save instance into _origin_network.
    setattr(stree.get_origin_network(), new_node.get_name(), new_node.get_instance())
    # Insert ast to __init__ function
    init_target = new_node.get_func_name()
    if isinstance(new_node, TreeNode):
        init_value = f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})"
    else:
        init_value = f"obj.{new_node.get_name()}"
    init_stmt = ast.parse(f"{init_target} = {init_value}").body[0]
    AstModifier.insert_ast_to_function(stree.get_init_func_ast(), init_stmt)
    # Insert ast to construct_function/class_internal_function
    base_ast = base_node.get_ast() if base_node else None
    manager_ast = node_manager.get_manager_ast()
    if not manager_ast:
        raise RuntimeError(f"ast_node_manager is None in node_manager {node_manager.get_manager_name()} "
                           "when inserting the ast.")
    AstModifier.insert_ast_to_ast(manager_ast, assign_ast, base_ast, before_node)
387
+
388
@staticmethod
def insert_to_ast_while_insert_function(new_node: CallFunction, base_node: Node, before_node: bool,
                                        node_manager: NodeManager, stree: 'SymbolTree'):
    """
    Update ast when a NodeType.CallFunction node is inserted.

    The assign statement of `new_node` is inserted into the ast of `node_manager`. After that the called
    function is either ignored (Python builtin), made accessible through an import, or parsed into nodes
    which are managed by `new_node`.

    Args:
        new_node (CallFunction): Node being inserted.
        base_node (Node): Node whose ast is the anchor of the insertion. May be None.
        before_node (bool): Whether the new ast is inserted before the ast of `base_node`.
        node_manager (NodeManager): NodeManager whose ast receives the assign statement.
        stree (SymbolTree): Symbol tree which `new_node` is inserted into.

    Raises:
        RuntimeError: If `node_manager` has no ast.
    """
    callee_name = str(new_node.get_func_name())
    # Reuse the assign statement of the node or generate one.
    call_stmt = new_node.get_ast()
    if call_stmt is None:
        call_stmt = new_node.update_ast_node()
    # Insert the statement into the ast of node_manager.
    anchor_ast = None if not base_node else base_node.get_ast()
    manager_ast = node_manager.get_manager_ast()
    if not manager_ast:
        raise RuntimeError(f"ast_node_manager is None in node_manager {node_manager.get_manager_name()} "
                           "when inserting the ast.")
    AstModifier.insert_ast_to_ast(manager_ast, call_stmt, anchor_ast, before_node)

    # Python builtin functions are left alone.
    callee = new_node.get_instance()
    if isinstance(callee, types.BuiltinFunctionType):
        logger.warning(f"Ignore built in function: {callee_name}")
        return

    # Obtain the ast of the definition of the called function.
    callee_def = ast.parse(dedent(inspect.getsource(callee))).body[0]
    if SymbolTree._unparse_inserted_function or not isinstance(callee_def, ast.FunctionDef):
        logger.debug(f"import '{callee_name}' to access function object")
        # The function is not expanded, so an import is added to keep the function object accessible.
        defining_module = inspect.getmodule(callee)
        top_manager = node_manager.get_top_manager()
        if isinstance(top_manager, SymbolTree):
            belonging_ast = None
        else:
            belonging_ast = top_manager.get_manager_ast()
        stree.add_import(defining_module, callee_name, belonging_ast)
        return

    # new_node becomes the manager of the nodes parsed from the function body.
    new_node.set_manager_ast(callee_def)
    new_node.set_manager_node_namer(stree.get_node_namer())
    stree.get_external_ast()[callee_def] = []
    # Save imports of the file in which the function is defined.
    stree.save_imports_from_file(inspect.getabsfile(callee), callee_def)
    # Flatten the ast codes of the function.
    from ..ast_helpers import AstFlattener
    callee_def = AstFlattener().transform(callee_def, [callee_name], stree)
    # Parse the flattened ast into nodes under new_node.
    from ..parsers import ParserRegister
    function_parser = ParserRegister.instance().get_parser(ast.FunctionDef)
    function_parser.process(stree, callee_def, node_manager=new_node)
434
+
435
@staticmethod
def insert_to_ast_while_insert_node(new_node: Node, base_node: Node, before_node: bool):
    """
    Dispatch the ast update of an inserted node according to its node type.

    Args:
        new_node (Node): Node which has been inserted.
        base_node (Node): Node used as anchor of the insertion.
        before_node (bool): Whether `new_node` is inserted before `base_node`.

    Raises:
        ValueError: If `new_node` has no belonging symbol tree.
        ValueError: If the node manager of `new_node` is not a SymbolTree, CallFunction or ControlFlow.
        ValueError: If the type of `new_node` is not Input, CallCell, CallPrimitive, CallFunction or Tree.
    """
    stree = new_node.get_belong_symbol_tree()
    if not stree:
        raise ValueError("When inserting node to ast, the belonging symbol tree of new_node is None.")
    node_manager = new_node.get_node_manager()
    if not isinstance(node_manager, (SymbolTree, CallFunction, ControlFlow)):
        raise ValueError(f"When inserting node to ast, the node_manager of new_node {new_node.get_name()} can "
                         f"only be one of [SymbolTree, CallFunction, ControlFlow], but get {type(node_manager)}")

    node_type = new_node.get_node_type()
    if node_type == NodeType.Input:
        SymbolTree.insert_to_ast_while_insert_input(new_node, node_manager)
        return
    if node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
        SymbolTree.insert_to_ast_while_insert_cell_primitive(new_node, base_node, before_node, node_manager,
                                                             stree)
        return
    if node_type == NodeType.CallFunction:
        SymbolTree.insert_to_ast_while_insert_function(new_node, base_node, before_node, node_manager, stree)
        return
    raise ValueError(f"When insert node '{new_node.get_name()}' into ast, the type of node can only be "
                     f"one of [Input, CallCell, CallPrimitive, CallFunction, Tree], but got "
                     f"{node_type}.")
456
+
457
@staticmethod
def get_node_full_name(node: Node) -> str:
    """
    Get the dotted full name of `node`.

    The name of every enclosing node manager is prepended, outermost first.

    Args:
        node (Node): Node whose full name is queried.

    Returns:
        A str in the form of "outer_manager.inner_manager.node".
    """
    if isinstance(node, NodeManager):
        full_name = node.get_manager_name()
    else:
        full_name = node.get_name()
    # Walk through the managers which are nodes themselves.
    owner = node.get_node_manager()
    while isinstance(owner, Node):
        full_name = f"{owner.get_manager_name()}.{full_name}"
        owner = owner.get_node_manager()
    # The walk stops at the first manager which is not a Node; its name becomes the leading part.
    return f"{owner.get_manager_name()}.{full_name}"
469
+
470
def local_prim_inits(self) -> List[Node]:
    """
    Getter of `_local_prim_inits`.

    Returns:
        The list stored in `_local_prim_inits`, i.e. local primitives constructed during forward method.
    """
    return self._local_prim_inits
473
+
474
def finish_build(self):
    """Notify that the build is finished by adding an Event.TopologicalChangeEvent event."""
    self.add_event(Event.TopologicalChangeEvent)
477
+
478
def get_ori_cls_name(self) -> str:
    """
    Getter of `_ori_cls_name`.

    Returns:
        A str represents class name of original network.
    """
    return self._ori_cls_name
486
+
487
def get_opt_cls_name(self) -> str:
    """
    Getter of `_opt_cls_name`.

    Returns:
        A str represents class name of rewritten network.
    """
    return self._opt_cls_name
495
+
496
def get_module_ast(self):
    """
    Getter of `_module_ast`.

    Returns:
        The ast node stored in `_module_ast`, which represents the module of corresponding network class.
    """
    return self._module_ast
504
+
505
def set_module_ast(self, ast_node: ast.Module):
    """
    Setter of `_module_ast`.

    Args:
        ast_node (ast.Module): Ast node of the module in which corresponding network class is defined.
    """
    self._module_ast = ast_node
514
+
515
def get_ast_root(self):
    """
    Getter of `_root_ast`.

    Returns:
        The ast node stored in `_root_ast`, which represents the forward method of corresponding network class.
    """
    return self._root_ast
523
+
524
def set_ast_root(self, ast_node: ast.FunctionDef):
    """
    Setter of `_root_ast`.

    The same ast node is also registered as manager ast through NodeManager.set_manager_ast.

    Args:
        ast_node (ast.FunctionDef): Ast node of the forward method of corresponding network class.
    """
    self._root_ast = ast_node
    # Keep the ast held by the NodeManager part of this tree identical to the root ast.
    NodeManager.set_manager_ast(self, ast_node)
534
+
535
def get_class_ast(self):
    """
    Getter of `_class_ast`.

    Returns:
        The ast node stored in `_class_ast`, which represents corresponding network class.
    """
    return self._class_ast
543
+
544
def set_class_ast(self, ast_node: ast.ClassDef):
    """
    Setter of `_class_ast`.

    Args:
        ast_node (ast.ClassDef): Ast node of corresponding network class.
    """
    self._class_ast = ast_node
552
+
553
def get_init_func_ast(self):
    """
    Getter of `_init_func_ast`.

    Returns:
        The ast node stored in `_init_func_ast`, which represents the init method of corresponding network class.
    """
    return self._init_func_ast
561
+
562
def set_init_func_ast(self, ast_node: ast.FunctionDef):
    """
    Setter of `_init_func_ast`.

    Args:
        ast_node (ast.FunctionDef): Ast node of the init method of corresponding network class.
    """
    self._init_func_ast = ast_node
571
+
572
def get_origin_network(self):
    """
    Getter of `_origin_network`.

    Returns:
        The object stored in `_origin_network`, which represents the original network.
    """
    return self._origin_network
580
+
581
def get_nodes_dict(self):
    """
    Getter of `_nodes`.

    Returns:
        The container stored in `_nodes`.
    """
    return self._nodes
584
+
585
def get_node_namer(self):
    """
    Getter of `_node_namer`.

    Returns:
        The namer object stored in `_node_namer`.
    """
    return self._node_namer
588
+
589
def is_modified(self):
    """
    Tell whether the symbol tree is flagged as modified.

    Returns:
        The value of `_modified`. According to the original documentation the flag is raised once operations
        like insert/replace/erase/set_arg are called after the symbol tree is created.
    """
    return self._modified
597
+
598
def set_modified_true(self):
    """
    Flag the symbol tree as modified by setting `_modified` to True.

    Note:
        The original documentation states that this is used when 'if' exists in the original network: different
        instances of the original network tend to differ in that case, hence the class name should be updated.
    """
    self._modified = True
607
+
608
def get_import_asts(self):
    """
    Getter of `_import_asts`.

    Returns:
        The container stored in `_import_asts`.
    """
    return self._import_asts
611
+
612
def get_external_ast(self):
    """
    Getter of `_external_ast`.

    Returns:
        The container stored in `_external_ast`.
    """
    return self._external_ast
615
+
616
def get_father_class_ast(self):
    """
    Getter of `_father_class_ast`.

    Returns:
        The container stored in `_father_class_ast`.
    """
    return self._father_class_ast
619
+
620
def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
    """
    Getter of inputs in topological relation of current 'node_or_name'.

    Args:
        node_or_name (Union[Node, str]): An instance of node or a str represents name of node.

    Returns:
        A list of instances of Node as input nodes if 'node_or_name' belong to current SymbolTree. An empty list if
        'node_or_name' not belong to current SymbolTree.
    """
    real_node = self._get_real_node(node_or_name)
    if real_node is None:
        logger.info("Node(%s) is not belong to current SymbolTree", node_or_name)
        return []
    # Query the resolved node: 'node_or_name' may be a plain name which has no get_inputs().
    return real_node.get_inputs()
637
+
638
def get_node_users(self, node_or_name: Union[Node, str]) -> [Tuple[Node, int]]:
    """
    Getter of outputs in topological relation of current 'node_or_name'.

    Args:
        node_or_name (Union[Node, str]): An instance of node or a str represents name of node.

    Returns:
        A list of users collected from all targets of the node, without duplicates. An empty list if
        'node_or_name' not belong to current SymbolTree or if the node is an output node.
    """
    real_node = self._get_real_node(node_or_name)
    if real_node is None:
        logger.info("Node(%s) is not belong to current SymbolTree", node_or_name)
        return []
    if real_node.get_node_type() == NodeType.Output:
        return []
    node_users = []
    for target_users in real_node.get_target_users().values():
        if not target_users:
            continue
        # Compare user by user: testing the whole list against the elements of node_users never matches,
        # so duplicates would not be filtered.
        for user in target_users:
            if user not in node_users:
                node_users.append(user)
    return node_users
663
+
664
def before(self, node_or_name: Union[Node, str]) -> Position:
    """
    Create the insert position in front of 'node_or_name'.

    Note:
        According to the original documentation, topological order is not determined here; it is determined by
        arguments of node and updated by TopologicalManager automatically.

    Args:
        node_or_name (Union[Node, str]): An instance of node or a str represents name of node.

    Returns:
        A Position created from the belonging symbol tree of the node, the node and the flag True.

    Raises:
        RuntimeError: If 'node_or_name' can not be resolved to a node.
    """
    anchor = self._get_real_node(node_or_name)
    if anchor is not None:
        return Position.create(anchor.get_belong_symbol_tree(), anchor, True)
    raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
689
+
690
def after(self, node_or_name: Union[Node, str]) -> Position:
    """
    Create the insert position behind 'node_or_name'.

    Note:
        According to the original documentation, topological order is not determined here; it is determined by
        arguments of node and updated by TopologicalManager automatically.

    Args:
        node_or_name (Union[Node, str]): An instance of node or a str represents name of node.

    Returns:
        A Position created from the belonging symbol tree of the node, the node and the flag False.

    Raises:
        RuntimeError: If 'node_or_name' can not be resolved to a node.
    """
    anchor = self._get_real_node(node_or_name)
    if anchor is not None:
        return Position.create(anchor.get_belong_symbol_tree(), anchor, False)
    raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
714
+
715
def insert_node(self, new_node: Node, base_node: Node, before_node: bool, node_manager: NodeManager = None,
                insert_to_ast: bool = True):
    """
    Insert a node before or after base_node.

    The targets of `new_node` are registered in the target namer, custom objects in its normalized arguments are
    handled, `new_node` is bound to this SymbolTree and then handed to the responsible NodeManager. When the
    manager is this SymbolTree and `insert_to_ast` is True, the init-function ast and the construct-function ast
    are updated as well. Finally this SymbolTree is registered as observer of the new node (or of its sub-tree).

    Args:
        new_node (Node): Node to be inserted.
        base_node (Node): New node will be inserted before or after base_node.
        before_node (bool): Indicate whether new node is inserted before base_node.
        node_manager (NodeManager): NodeManager which receives the node. Default: None, means the NodeManager of
            `base_node` is used.
        insert_to_ast (bool): Indicate whether ast nodes need to be updated.

    Returns:
        An instance of node which has been inserted into SymbolTree.

    Raises:
        ValueError: If `new_node` already belongs to a SymbolTree.
        ValueError: If `base_node` belongs to another SymbolTree.
        RuntimeError: If `new_node` would be placed before or between input nodes.
        RuntimeError: If both `node_manager` and `base_node` are None.
    """
    if new_node.get_belong_symbol_tree():
        raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {new_node.get_name()}")

    if base_node is not None:
        # The anchor must live in this SymbolTree.
        anchor_tree = base_node.get_belong_symbol_tree()
        if anchor_tree is not None and anchor_tree is not self:
            raise ValueError(f"Position is not in current SymbolTree, node:{anchor_tree.get_ori_cls_name()}, "
                             f"current: {self.get_ori_cls_name()}.")
        # Nothing may be placed in front of an input node or in between two input nodes.
        if base_node.get_node_type() == NodeType.Input:
            following = base_node.get_next()
            followed_by_input = following is not None and following.get_node_type() == NodeType.Input
            if before_node or followed_by_input:
                raise RuntimeError("Can not insert a node before or between parameters:", base_node.get_name())

    # Remember the target names so that later targets stay unique.
    node_targets = new_node.get_targets()
    if node_targets:
        for node_target in node_targets:
            self._target_namer.add_name(str(node_target))

    self._handle_custom_obj_in_normalized_args(new_node)

    # Fall back to the manager of the anchor node.
    if node_manager is None:
        if base_node is None:
            raise RuntimeError("node_manager and base_node cannot both be None when inserting a node.")
        node_manager = base_node.get_node_manager()

    new_node.set_belong_symbol_tree(self)

    if node_manager is not self:
        node_manager.insert_node(new_node, base_node, before_node, insert_to_ast)
    else:
        NodeManager.insert_node(self, new_node, base_node, before_node)
        if insert_to_ast:
            # update init-function-ast and construct-function-ast
            self.insert_to_ast_while_insert_node(new_node, base_node, before_node)

    # Observe the inserted node, which is used to update the _modified flag.
    if new_node.get_node_type() == NodeType.Tree:
        new_node.symbol_tree.reg_observer(self)
    elif isinstance(new_node, NodeManager):
        new_node.reg_observer(self)

    return new_node
803
+
804
def append_node(self, node: Node, node_manager: NodeManager = None, append_to_ast: bool = True) -> Node:
    """
    Append a node behind the tail node of a NodeManager.

    Args:
        node (Node): An instance of node to be appended.
        node_manager (NodeManager): NodeManager which receives the node. Default: None, means this SymbolTree.
        append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
            True.

    Returns:
        An instance of node which has been appended to SymbolTree.
    """
    target_manager = self if node_manager is None else node_manager
    tail_node = target_manager.get_tail()
    return self.insert_node(node, tail_node, False, target_manager, append_to_ast)
821
+
822
def append_origin_field(self, node: Node, node_manager: NodeManager = None) -> Node:
    """
    Append a node which originates from an existing statement of the forward method.

    The ast node of such a node already exists, so the node is appended without updating the ast.
    This method is called while building SymbolTree usually.

    Args:
        node (Node): An instance of node to be appended.
        node_manager (NodeManager): NodeManager which receives the node. Default: None, which is forwarded to
            `append_node` unchanged.

    Returns:
        An instance of node which has been appended to SymbolTree.
    """
    appended = self.append_node(node, node_manager, False)
    return appended
838
+
839
def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None,
                      node_manager: NodeManager = None):
    """
    Append an input node to SymbolTree corresponding to parameter of forward method of network class.
    This method is called while building SymbolTree usually.

    Args:
        ast_node (ast.AST): A ast Node corresponding to current parameter.
        param_name (str): A str represents name of parameter of forward method of network class.
        default (ScopedValue, optional): A ScopedValue represents default value of parameter. Default is None which
            means parameter has no default value.
        node_manager (NodeManager): NodeManager which receives the node. Default: None, means this SymbolTree.

    Returns:
        An instance of input node which has been appended to SymbolTree. None if `param_name` is "self", for
        which no input node is created.

    Raises:
        RuntimeError: If an existing input node does not have exactly one target.
        RuntimeError: If the target of an existing input node is not a naming value with empty scope.
        RuntimeError: If an input node for `param_name` already exists.
    """
    if param_name == "self":
        return None
    # check param_name duplicated
    if node_manager is None:
        node_manager = self
    for input_node in node_manager.get_input_nodes():
        targets = input_node.get_targets()
        if len(targets) != 1:
            raise RuntimeError("targets should have 1 elements")
        target: ScopedValue = targets[0]
        if target.type != ValueType.NamingValue:
            raise RuntimeError("target.type should equal to ValueType.NamingValue")
        if target.scope != "":
            raise RuntimeError("target.scope should be empty")
        exist_param = target.value
        if exist_param == param_name:
            raise RuntimeError("input duplicated:", param_name)
    input_node = Node.create_input_node(ast_node, param_name, default, name=f"input_{param_name}")
    # Hand the appended node back to the caller, as the contract of this method promises.
    return self.append_origin_field(input_node, node_manager)
875
+
876
def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST,
                           node_manager: NodeManager = None) -> Optional[Node]:
    """
    Append a python node unless 'ast_node' is None or an empty list/dict.
    This method is called while building SymbolTree usually.

    Args:
        ast_scope (ast.AST): A ast node represents ast node of scope of node.
        ast_node (ast.AST): A ast node represents ast node.
        node_manager (NodeManager): NodeManager which receives the node. Default: None, which is forwarded to
            `append_python_node` unchanged.

    Returns:
        An instance of python node if a new node has been appended to SymbolTree else None.
    """
    is_empty_container = isinstance(ast_node, (list, dict)) and not ast_node
    if ast_node is None or is_empty_container:
        return None
    return self.append_python_node(ast_scope, ast_node, node_manager)
897
+
898
def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST, node_manager: NodeManager = None) -> Node:
    """
    Wrap 'ast_node' into a python node and append it to a NodeManager.
    This method is called while building SymbolTree usually.

    Args:
        ast_scope (ast.AST): A ast node represents ast node of scope of node. Only used for logging here.
        ast_node (ast.AST): A ast node represents ast node.
        node_manager (NodeManager): NodeManager which receives the node. Default: None, means this SymbolTree.

    Returns:
        An instance of python node which has been appended to SymbolTree.
    """
    logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__)
    # The python node is named after the type of its ast node.
    python_node = Node.create_python_node(ast_node, type(ast_node).__name__)
    if node_manager is not None and node_manager is not self:
        node_manager.append_python_node(python_node)
    else:
        NodeManager.append_python_node(self, python_node)
    return python_node
920
+
921
def set_output(self, return_value: str, arg_index: int, return_idx: int = 0,
               node_manager: NodeManager = None) -> Node:
    """
    Update one value of a return node.

    Args:
        return_value (str): A str represents new return value.
        arg_index (int): A int indicates which value in return to be updated.
        return_idx (int): A int indicates which return node to be updated. Default: 0.
        node_manager (NodeManager): NodeManager whose return nodes are used. Default: None, means this
            SymbolTree.

    Returns:
        An instance of node represents return node after updated.

    Raises:
        RuntimeError: If there is no return node.
        RuntimeError: If `return_idx` is not less than the number of return nodes.
    """
    if node_manager is None:
        return_nodes = NodeManager.get_returns(self)
    else:
        return_nodes = node_manager.get_returns()
    if not return_nodes:
        raise RuntimeError("Current node_manager has no output")
    return_num = len(return_nodes)
    if return_idx >= return_num:
        raise RuntimeError(f"return_idx {return_idx} should be less than return num {return_num}.")
    chosen_return = return_nodes[return_idx]
    self.set_node_arg(chosen_return, arg_index, return_value)
    return chosen_return
944
+
945
def erase_node(self, node_or_name: Union[Node, str]) -> Node:
    """
    Erase a node from SymbolTree.

    The node is removed from its NodeManager. When the manager is this SymbolTree, the ast of the node is removed
    from the root ast as well. Afterwards the node no longer belongs to a SymbolTree and its name is recorded in
    `_deleted_node`.

    Args:
        node_or_name (Union[Node, str]): An instance of node or a str represents name of node.

    Returns:
        An instance of node which has been erased from SymbolTree.

    Raises:
        RuntimeError: If 'node_or_name' is not in current SymbolTree.
        RuntimeError: If erase corresponding ast node failed.
    """
    node = self._get_real_node(node_or_name)
    if node is None:
        raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
    owner = node.get_node_manager()

    logger.debug(f"[earse]stree: {self.get_opt_cls_name()}, "
                 f"node_manager: {owner.get_manager_name()}, "
                 f"code: {astunparse.unparse(node.get_ast()).strip()}, "
                 f"node_name:{node.get_name()}")

    if owner is not self:
        # Other managers take care of the erasing themselves.
        owner.erase_node(node)
    else:
        NodeManager.erase_node(self, node)
        node_ast = node.get_ast()
        if isinstance(node, ControlFlow):
            erased = AstModifier.earse_ast_of_control_flow(self._root_ast.body, node_ast, node.is_orelse)
        else:
            erased = AstModifier.erase_ast_from_function(self._root_ast, node_ast)
        if not erased:
            raise RuntimeError(f"erase node failed, node {node.get_name()} not in function ast tree.")
    node.set_belong_symbol_tree(None)
    self._deleted_node.append(node.get_name())
    return node
986
+
987
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
    """
    Replace an old_node with a node list.

    The new nodes are inserted one after another behind the old node, then the old node is erased.

    Args:
        old_node (Node): Node to be replaced.
        new_nodes (list[Node]): Node list to replace in. Must not be empty: with an empty list the old node is
            erased and an IndexError is raised afterwards.

    Returns:
        Last node in new_nodes list.

    Raises:
        RuntimeError: If 'old_node' is not belong to current SymbolTree.
    """
    real_old_node = self._get_real_node(old_node)
    if real_old_node is None:
        raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
    node_manager = real_old_node.get_node_manager()
    # Anchor on the resolved node, so that the node which was looked up is the one which is used.
    base_node = real_old_node
    for node in new_nodes:
        self.insert_node(node, base_node, False, node_manager, True)
        base_node = node
    self.erase_node(real_old_node)
    return new_nodes[-1]
1014
+
1015
def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]):
    """
    Set argument of 'node'.

    Args:
        node (Union[Node, str]): Node to be modified. Can be a node or name of node.
        index (int): Indicate which input being modified.
        arg (Union[ScopedValue, str]): New argument to been set.

    Raises:
        RuntimeError: If 'node' is not belong to current SymbolTree.
    """
    real_node = self._get_real_node(node)
    if real_node is None:
        raise RuntimeError("Node is not belong to current SymbolTree: ", node)

    # Operate on the resolved node: 'node' may be a plain name.
    new_arg, old_arg = real_node.set_arg(arg, index)
    real_node.get_node_manager().on_update_arg(real_node, index, old_arg, new_arg)
1034
+
1035
def set_node_arg_by_node(self, dst_node: Union[Node, str], arg_idx: int, src_node: Union[Node, str],
                         out_idx: Optional[int] = None):
    """
    Use a target of 'src_node' as argument of 'dst_node'.

    Args:
        dst_node (Node): Node to be modified. Can be a node or name of node.
        arg_idx (int): Indicate which input being modified.
        src_node (Node): Node as new input. Can be a node or name of node.
        out_idx ([int, optional]): Indicate which output of 'src_node' as new input of 'dst_node'. Default is None
            which means the only output of 'src_node' is used.

    Raises:
        RuntimeError: If 'dst_node' is not belong to current SymbolTree.
        RuntimeError: If 'src_node' is not belong to current SymbolTree.
        RuntimeError: If 'out_idx' is out of range.
        RuntimeError: If 'src_node' does not have exactly one output while 'out_idx' is None.
    """
    consumer = self._get_real_node(dst_node)
    if consumer is None:
        raise RuntimeError("dst_node is not belong to current SymbolTree: ", dst_node)
    producer = self._get_real_node(src_node)
    if producer is None:
        raise RuntimeError("src_node is not belong to current SymbolTree: ", src_node)

    outputs = producer.get_targets()
    output_num = len(outputs)
    if out_idx is None:
        # Without an explicit index the choice must be unambiguous.
        if output_num != 1:
            raise RuntimeError("node should has one output when out_idx is not provided")
        out_idx = 0
    if out_idx >= output_num:
        raise RuntimeError("out_idx out of range: ", out_idx)
    consumer.set_arg(outputs[out_idx], arg_idx)
    consumer.get_node_manager().on_update_arg_by_node(consumer, arg_idx, producer, out_idx)
1071
+
1072
def unique_name(self, name: str):
    """
    Get a unique name in the symboltree.

    Args:
        name (str): Name which is handed to the target namer.

    Returns:
        The result of `get_name` of the target namer for `name`.
    """
    namer = self._target_namer
    return namer.get_name(name)
1075
+
1076
def unique_func_name(self, name: str):
    """
    Get a function name which is not yet an attribute of the original network.

    Args:
        name (str): Preferred name.

    Returns:
        `name` itself if the original network has no such attribute, else `name` followed by "_" and the
        smallest positive number which makes the name unused.
    """
    network = self._origin_network
    candidate = name
    counter = 0
    while hasattr(network, candidate):
        counter += 1
        candidate = f"{name}_{counter}"
    return candidate
1084
+
1085
def set_node_target(self, node: Union[Node, str], index: int, target: Union[ScopedValue, str]):
    """
    Set target of `node` .

    Args:
        node (Union[Node, str]): Node to be modified. Can be a node or name of node.
        index (int): Indicate which target being modified.
        target (Union[ScopedValue, str]): New target to been set. A str is converted to a naming value.

    Raises:
        ValueError: If `node` is not belong to current SymbolTree.
        ValueError: If index of `node` 's target is not less than number of targets.
    """
    real_node = self._get_real_node(node)
    if real_node is None:
        raise ValueError("Node is not belong to current SymbolTree: ", node)
    if isinstance(target, str):
        target = ScopedValue.create_naming_value(target)
    # Operate on the resolved node: 'node' may be a plain name.
    targets = real_node.get_targets()
    if index >= len(targets):
        raise ValueError(f"Index of node's target should be less than {len(targets)}, but got {index}")
    old_target = targets[index]
    targets[index] = target
    real_node.set_targets(targets)
    self._topo_mgr.on_update_target(real_node, index, old_target, target)
1111
+
1112
def all_nodes(self, subtree_nodes: bool = True):
    """
    Collect the nodes of this tree, of all nested node managers and optionally of sub symbol trees.

    Args:
        subtree_nodes (bool): Whether include nodes in subtree. Default: True.

    Returns:
        A list of nodes.
    """
    collected = []
    pending_managers = [self]
    while pending_managers:
        manager = pending_managers.pop()
        collected.extend(manager.nodes())
        # Nodes which are managers themselves are visited later.
        pending_managers.extend(node for node in manager.nodes() if isinstance(node, NodeManager))
    if not subtree_nodes:
        return collected
    for tree_node in self.get_tree_nodes():
        collected.extend(tree_node.symbol_tree.all_nodes())
    return collected
1135
+
1136
def get_node_from_name(self, node_name: str):
    """
    Search a node by name in this tree and in all nested node managers.

    Args:
        node_name (str): A str represents name of node as key of query.

    Returns:
        An instance of Node if found else None.
    """
    pending_managers = [self]
    while pending_managers:
        manager = pending_managers.pop()
        found = manager.get_node(node_name)
        if found:
            return found
        # Not found in this manager, continue in the managers nested in it.
        for candidate in manager.nodes():
            if isinstance(candidate, NodeManager):
                pending_managers.append(candidate)
    return None
1156
+
1157
def get_node_tabulate(self, all_nodes: bool = False) -> str:
    """
    Get nodes information and nodes' topological relations.

    Args:
        all_nodes (bool): Whether to include nodes out of construct functions, such as nodes in nested node
            managers and in sub symbol trees.

    Returns:
        String of nodes' information and topological relations. An empty string if the library `tabulate` is
        not installed.
    """
    # The import only probes the availability of tabulate; presumably the dump below relies on it.
    try:
        from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
    except ImportError:
        logger.warning("print_node_tabulate relies on the library `tabulate`, "
                       "which could not be found on this machine. Run `pip "
                       "install tabulate` to install the library.")
        return ""
    table_text = NodeManager.dump(self, self.get_manager_name())
    if not all_nodes:
        return table_text
    pending_managers = [self]
    while pending_managers:
        manager = pending_managers.pop()
        for node in manager.nodes():
            if not isinstance(node, NodeManager):
                continue
            table_text += node.dump(SymbolTree.get_node_full_name(node))
            pending_managers.append(node)
    for tree_node in self.get_tree_nodes():
        table_text += tree_node.symbol_tree.get_node_tabulate(all_nodes)
    return table_text
1188
+
1189
def dump(self):
    """Dump the graph of this symbol tree through a SymbolTreeDumper."""
    SymbolTreeDumper(self).dump()
1193
+
1194
def check_body_exist(self, body, code_bodies):
    """
    Check whether `body` already exists in `code_bodies`.

    Args:
        body (ast.AST): The ast node to be checked.
        code_bodies (list): Ast nodes that have been collected so far.

    Returns:
        A bool. Import/ImportFrom/Expr nodes are compared through their unparsed code string, which is
        recorded in `self._tmp_import_strs` on first sight. ClassDef and FunctionDef nodes are looked up
        by name in `code_bodies`. Any other node is reported as not existing.
    """
    if isinstance(body, (ast.Import, ast.ImportFrom, ast.Expr)):
        unparsed = astunparse.unparse(body)
        already_seen = unparsed in self._tmp_import_strs
        if not already_seen:
            self._tmp_import_strs.append(unparsed)
        return already_seen

    # Pick the finder matching the definition kind; everything else is never a duplicate.
    if isinstance(body, ast.ClassDef):
        finder_type = AstClassFinder
    elif isinstance(body, ast.FunctionDef):
        finder_type = AstFunctionFinder
    else:
        return False

    module_kwargs = {"type_ignores": []} if sys.version_info >= (3, 9) else {}
    finder = finder_type(ast.Module(body=code_bodies, **module_kwargs))
    return bool(finder.find_all(body.name))
1223
+
1224
def deduplicate_unmodified_stree(self, code_bodies):
    """
    Remove duplicated class definitions of unmodified symbol trees from `code_bodies`.

    The init function of unmodified symbol trees may differ when their subnets are initialized by
    different arguments. Once `code_bodies` is fully generated and subnet names are updated, trees
    whose init function code is equal are merged: the duplicated class is removed and its name is
    replaced by the name of the first tree with the same init code. The method calls itself again
    after every pass that removed something.

    Args:
        code_bodies (list): Generated ast nodes, modified in place.
    """
    module_kwargs = {"type_ignores": []} if sys.version_info >= (3, 9) else {}
    class_finder = AstClassFinder(ast.Module(body=code_bodies, **module_kwargs))
    name_replacer = AstReplacer(ast.Module(body=code_bodies, **module_kwargs))

    removed_any = False
    for strees in self._tmp_unmodified_strees.values():
        if len(strees) <= 1:
            continue
        init_codes = [astunparse.unparse(tree.get_init_func_ast()) for tree in strees]
        duplicate_positions = []
        for position, init_code in enumerate(init_codes):
            first_position = init_codes.index(init_code)
            if first_position == position:
                # first occurrence of this init code, keep it
                continue
            first_stree_cls_name = strees[first_position].get_opt_cls_name()
            duplicated_stree_cls_name = strees[position].get_opt_cls_name()
            logger.debug(f"replace stree:{duplicated_stree_cls_name} to {first_stree_cls_name}.")
            # drop the duplicated class definition(s) from the generated code
            for duplicated_cls_ast in class_finder.find_all(duplicated_stree_cls_name):
                code_bodies.remove(duplicated_cls_ast)
            # make remaining references point to the first class
            name_replacer.replace_all(duplicated_stree_cls_name, first_stree_cls_name)
            duplicate_positions.append(position)
            removed_any = True
        # pop from the back so that the remaining positions stay valid
        for position in reversed(duplicate_positions):
            strees.pop(position)

    if not removed_any:
        return
    # names of subnets have been updated, so another pass may find new duplicates
    self._tmp_replacers.append(name_replacer)
    self.deduplicate_unmodified_stree(code_bodies)
1269
+
1270
def update_unmodified_stree(self, stree, code_bodies) -> bool:
    """
    Reuse the class definition of the first unmodified symbol tree of the same network type.

    Args:
        stree (SymbolTree): The symbol tree to be checked.
        code_bodies (list): Generated ast nodes in which class names are replaced.

    Returns:
        A bool, True when `stree` was redirected to the class of the first unmodified symbol tree,
        so its own definition does not need to be generated. False when `stree` is modified, is
        the first tree of its network type, or has a different init function.
    """
    # every modified class is exported to code on its own
    if stree.is_modified():
        logger.debug(f"stree:{stree.get_opt_cls_name()} is modified.")
        return False

    network_type = type(stree.get_origin_network())
    recorded_strees = self._tmp_unmodified_strees.get(network_type)
    if not recorded_strees:
        self._tmp_unmodified_strees[network_type] = [stree]
        logger.debug(f"stree:{stree.get_opt_cls_name()} is the first stree.")
        return False

    first_stree = recorded_strees[0]
    first_stree_cls_name = first_stree.get_opt_cls_name()
    own_init_code = astunparse.unparse(stree.get_init_func_ast())
    first_init_code = astunparse.unparse(first_stree.get_init_func_ast())
    if own_init_code != first_init_code:
        # The init ast may still change after sub trees are inserted, so the tree is only
        # recorded here and deduplicated later.
        recorded_strees.append(stree)
        logger.debug(f"init func different, stree:{stree.get_opt_cls_name()}, first_stree:{first_stree_cls_name}.")
        return False

    # The class already exists in code_bodies: rename usages to the first unmodified class.
    module_kwargs = {"type_ignores": []} if sys.version_info >= (3, 9) else {}
    replacer = AstReplacer(ast.Module(body=code_bodies, **module_kwargs))
    logger.debug(f"replace stree:{stree.get_opt_cls_name()} to {first_stree_cls_name}.")
    replacer.replace_all(stree.get_class_ast().name, first_stree_cls_name)
    self._tmp_replacers.append(replacer)
    return True
1305
+
1306
def init_code_bodies(self, code_bodies: list) -> int:
    """
    Init code bodies with the basic imports followed by the user custom codes.

    Args:
        code_bodies (list): List of ast nodes, filled in place.

    Returns:
        An int, the length of `code_bodies` after initialization.
    """
    def new_divider():
        # ast.Expr of name '#' is used as divider between sections of generated code
        return ast.Expr(ast.Name("#", ast.Load()))

    def import_from(module_name, name, asname=None):
        return ast.ImportFrom(module=module_name, names=[ast.alias(name=name, asname=asname)], level=0)

    # basic imports
    code_bodies.extend([
        ast.Import([ast.alias(name='sys', asname=None)]),
        ast.Import([ast.alias(name='mindspore', asname=None)]),
        import_from('mindspore', 'nn'),
        import_from('mindspore.nn', 'Cell'),
        import_from('mindspore.ops', 'functional', 'F'),
        new_divider(),
    ])
    # user custom codes
    code_bodies.extend(self.get_custom_codes())
    code_bodies.append(new_divider())
    return len(code_bodies)
1322
+
1323
def convert_stree_to_code_bodies(self, stree: 'SymbolTree', code_bodies: list, dividing_pos=0) -> int:
    """
    Convert `stree` to ast nodes and insert them into `code_bodies`, in this order:
    - external function asts (each preceded by its imports)
    - father class asts (each preceded by its imports)
    - import asts of the symbol tree
    - bodies of the module ast of the symbol tree
    - sub symbol trees, inserted recursively at the dividing position

    Args:
        stree (SymbolTree): The symbol tree to be converted.
        code_bodies (list): List of ast nodes, modified in place.
        dividing_pos (int): Position in `code_bodies` where insertion starts. Default: ``0``.

    Returns:
        An int, the new dividing position, which lies below the inserted external functions
        and father classes.
    """
    cursor = dividing_pos

    def place(ast_item):
        """Insert one ast node at the cursor and advance the cursor."""
        nonlocal cursor
        code_bodies.insert(cursor, ast_item)
        cursor += 1

    def place_definition(definition_ast, related_imports):
        """Insert a definition, its not-yet-seen imports and a divider, unless it already exists."""
        if self.check_body_exist(definition_ast, code_bodies):
            return
        self._tmp_import_strs.clear()
        for import_ast in related_imports:
            if not self.check_body_exist(import_ast, code_bodies):
                place(import_ast)
        place(definition_ast)
        place(ast.Expr(ast.Name("#", ast.Load())))

    # external functions, visited in reversed order
    for func_ast, func_imports in reversed(stree.get_external_ast().items()):
        place_definition(func_ast, func_imports)

    # father classes
    for class_ast, class_imports in stree.get_father_class_ast().items():
        place_definition(class_ast, class_imports)

    # external functions and father class are above the dividing_pos to support deduplication.
    dividing_pos = cursor

    # imports of the symbol tree
    self._tmp_import_strs.clear()
    for import_ast in stree.get_import_asts():
        if not self.check_body_exist(import_ast, code_bodies):
            place(import_ast)

    # bodies of the module ast (class definition) of the symbol tree
    module_ast = stree.get_module_ast()
    if module_ast:
        for module_body in module_ast.body:
            if not self.check_body_exist(module_body, code_bodies):
                place(module_body)

    place(ast.Expr(ast.Name("#", ast.Load())))

    # sub symbol trees
    for tree_node in stree.get_tree_nodes():
        sub_stree = tree_node.symbol_tree
        # an unmodified class reuses the name of the first class, nothing more to generate
        if self.update_unmodified_stree(sub_stree, code_bodies):
            continue
        dividing_pos = self.convert_stree_to_code_bodies(tree_node.symbol_tree, code_bodies, dividing_pos)

    return dividing_pos
1400
+
1401
def get_code(self) -> str:
    """
    Get source code of modified network.

    Class and import names are temporarily replaced in the asts while the code is generated. The
    replacements recorded in `self._tmp_replacers` are always undone before returning, including
    when code generation raises, so the asts held by the symbol trees are not left renamed.

    Returns:
        A str represents source code of modified network.
    """
    self._tmp_import_strs.clear()
    self._tmp_unmodified_strees.clear()
    self._tmp_replacers.clear()
    code_bodies = []
    begin_pos = self.init_code_bodies(code_bodies)
    try:
        self.convert_stree_to_code_bodies(self, code_bodies, begin_pos)
        self.deduplicate_unmodified_stree(code_bodies)
        if sys.version_info >= (3, 9):
            gencode_module = ast.Module(body=code_bodies, type_ignores=[])
        else:
            gencode_module = ast.Module(body=code_bodies)
        SymbolTree._remove_unused_import(gencode_module)
        self._process_duplicate_name_modules(gencode_module)
        SymbolTree._remove_duplicated_import(gencode_module)
        SymbolTree._remove_arg_annotations(gencode_module)
        # NOTE(review): locations are fixed on self._module_ast, not on gencode_module - confirm intended.
        ast.fix_missing_locations(self._module_ast)
        code = astunparse.unparse(gencode_module)
    finally:
        # Revert the class name to its original state, even if generation failed
        for replacer in self._tmp_replacers:
            replacer.undo_all()
    return code
1429
+
1430
def get_network(self):
    """
    Build the modified network.

    The rewritten class is loaded through `_get_cls_through_file`, instantiated with the origin
    network, and the properties of the origin network are merged into the new instance.

    Returns:
        A network object.
    """
    network_cls = self._get_cls_through_file()
    rebuilt_net = network_cls(self._origin_network)
    self._merge_origin_property(rebuilt_net)
    # Parameter names are refreshed to fix the duplicated names bug,
    # which occurs after inserting cell to celllist/sequentialcell.
    rebuilt_net.update_parameters_name()
    return rebuilt_net
1444
+
1445
def set_saved_file_name(self, file_name: str):
    """
    Set the name of the file the network is saved to.

    Args:
        file_name (str): File name; the suffix ".py" is appended when it is missing.
    """
    self._saved_file_name = file_name if file_name.endswith(".py") else file_name + ".py"
1450
+
1451
def get_saved_file_name(self):
    """Get the file name stored in `self._saved_file_name`."""
    return self._saved_file_name
1453
+
1454
def save_network_to_file(self):
    """
    Write the source code of the modified network to `self._saved_file_name`.

    The code is generated before the file system is touched, so that a failure in `get_code`
    neither deletes an existing file nor leaves an empty one behind. An existing regular file
    at the resolved path is removed first; the new file is created with owner-only permissions.
    """
    source = self.get_code()
    abs_path = os.path.realpath(self._saved_file_name)
    if os.path.isfile(abs_path):
        os.remove(abs_path)
    # O_TRUNC guarantees no stale bytes survive if the path still exists at open time
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    with os.fdopen(os.open(self._saved_file_name, flags, stat.S_IRWXU), 'wb') as f:
        f.write(source.encode('utf-8'))
        f.flush()
1462
+
1463
+
1464
def flatten_nodes(self, node, erase_another_branch: bool = False, erase_nodes_after_return: bool = False):
    """
    Flatten nodes in ControlFlow node: move them into the node manager that contains `node`.

    Args:
        node (ControlFlow): The control flow node whose nodes are moved out.
        erase_another_branch (bool): Whether to erase the other branch of the control flow
            as well. Default: ``False``.
        erase_nodes_after_return (bool): Whether to erase nodes of the upper node manager which
            come after an output node. Default: ``False``.

    Raises:
        ValueError: If `node` is not a ControlFlow, or if it is not contained in a SymbolTree,
            CallFunction or ControlFlow.
    """
    if not isinstance(node, ControlFlow):
        raise ValueError(f"For flatten_nodes, the type of node can only be ControlFlow, but got {type(node)}.")
    parent_manager = node.get_node_manager()
    if isinstance(parent_manager, (SymbolTree, CallFunction)):
        parent_bodies = parent_manager.get_manager_ast().body
    elif isinstance(parent_manager, ControlFlow):
        parent_bodies = parent_manager.get_manager_ast()
    else:
        raise ValueError("For flatten_nodes, the node can only be contained in [SymbolTree, CallFunction, "
                         f"ControlFlow], but the node is in {type(parent_manager)}.")

    # Each moved node is inserted after the previous one, starting behind the control flow itself.
    anchor = node.orelse_node or node.body_node
    for inner in list(node.nodes()):
        self.erase_node(inner)
        self.insert_node(inner, anchor, False, parent_manager, False)
        AstModifier.insert_ast_to_bodies(parent_bodies, inner.get_ast(), anchor.get_ast(), False)
        anchor = inner
    self.erase_node(node)

    # remove another branch
    if erase_another_branch:
        if node.is_orelse:
            self.erase_node(node.body_node)
        elif node.orelse_node is not None:
            self.erase_node(node.orelse_node)

    # remove nodes after return node
    if not erase_nodes_after_return:
        return
    seen_return = False
    for sibling in parent_manager.nodes():
        if seen_return:
            logger.warning(f"Node {sibling.get_name()} which is behind the flatten return node is "
                           f"automatically erased.")
            self.erase_node(sibling)
        elif sibling.get_node_type() == NodeType.Output:
            seen_return = True
1499
+
1500
def eval_ast_result(self, ast_node: ast.AST) -> (bool, bool):
    """
    Eval ast_node and get result, only used in control flow node.

    Args:
        ast_node (ast.AST): The expression ast node to be evaluated.

    Returns:
        A tuple of two bools. The first one tells whether the evaluation succeeded, the second
        one is the truth value of the result (always False when the evaluation failed).
    """
    # ast.Constant can be check without eval
    if isinstance(ast_node, ast.Constant):
        # read the value from the node itself (the `ast` module has no attribute `value`)
        return True, bool(ast_node.value)
    # Get the module where the code of ast_node is located
    file_path = inspect.getfile(type(self.get_origin_network()))
    module = None
    for m in list(sys.modules.values()):
        if hasattr(m, "__file__") and m.__file__ and os.path.normcase(m.__file__) == os.path.normcase(file_path):
            module = m
            break
    if not module:
        logger.warning("Failed to get module of ast_node.")
        return False, False
    # eval ast_node and get result
    logger.debug(f"Eval ast node: {astunparse.unparse(ast_node)}")
    ast_expr = ast.Expression(ast_node)
    ast_expr = ast.fix_missing_locations(ast_expr)
    try:
        # The expression is compiled from the ast node of the network code, not from a raw string.
        # Local names of this method are visible to the expression through locals(), so they are
        # kept unchanged.
        # pylint: disable=eval-used
        result = eval(compile(ast_expr, "eval_ast_result", "eval"), {**globals(), **module.__dict__}, locals())
    except Exception as e: # pylint: disable=broad-except
        # best effort: an expression that cannot be evaluated statically is reported as failure
        logger.debug(f"Cannot get result of ast_node by eval, err:{e}")
        return False, False
    logger.debug(f"Eval ast result success, result: {result}")
    return True, bool(result)
1530
+
1531
def flatten_static_if_control_flow(self):
    """
    Flatten control flow nodes whose test result is already known.

    For every ControlFlow node with a non-None `test_result`, the codes of the branch that will be
    executed are flattened and the other branch is erased. When the test is false and there is
    no else branch, the body is simply erased.
    """
    for candidate in list(self.all_nodes()):
        owner_tree = candidate.get_belong_symbol_tree()
        if not owner_tree:
            # the node has been erased
            continue
        if not isinstance(candidate, ControlFlow) or candidate.test_result is None:
            continue
        if candidate.test_result:
            owner_tree.flatten_nodes(candidate.body_node, True, True)
        elif candidate.orelse_node is not None:
            owner_tree.flatten_nodes(candidate.orelse_node, True, True)
        else:
            owner_tree.erase_node(candidate.body_node)
1548
+
1549
def add_custom_codes(self, code: str):
    """
    Add user custom codes.

    Args:
        code (str): Python source; its top-level statements are parsed and appended to the
            custom codes of this symbol tree.
    """
    self._custom_codes.extend(ast.parse(code).body)
1553
+
1554
def get_custom_codes(self) -> List[ast.AST]:
    """Get user custom codes, i.e. the list of ast nodes stored in `self._custom_codes`."""
    return self._custom_codes
1557
+
1558
def save_file_path_to_sys(self, level_num, file_path, belonging_ast: ast.AST = None):
    """
    Save `import sys` and `sys.path.insert(0, <dir>)` into the import list of `belonging_ast`.

    `level_num` is used when level exist in ast.ImportFrom.

    When level_num = 0(e.g. from xxx import yyy), current path will be saved.
    When level_num = 1(e.g. from .xxx import yyy), current path will be saved.
    When level_num = 2(e.g. from ..xxx import yyy), the path one level above the current path will be saved.

    Args:
        level_num (int): Level of the import; every level above 1 moves one directory up.
        file_path (str): Path of the file whose directory is saved.
        belonging_ast (ast.AST): Ast which the imports belong to, used to select the import
            list. Default: ``None``.
    """
    dir_path = os.path.dirname(os.path.realpath(file_path))
    dir_path = os.path.normpath(os.path.normcase(dir_path))
    # range() is empty for level_num <= 1, so the current directory is kept in that case
    for _ in range(level_num - 1):
        dir_path = os.path.dirname(dir_path)
    # repr() yields a valid literal for every path, including paths containing quotes or ending
    # with a backslash, which a hand-written r'...' literal cannot express.
    sys_path_insert_ast = ast.parse(f"sys.path.insert(0, {dir_path!r})").body[0]
    # add imports to import_asts of belonging_ast
    import_asts = self._get_imports_list_of_ast(belonging_ast)
    import_asts.append(ast.Import([ast.alias(name='sys', asname=None)]))
    import_asts.append(sys_path_insert_ast)
1577
+
1578
def save_imports_from_file(self, file_path, belonging_ast: ast.AST = None):
    """
    Save the imports found in a source file into the import list of `belonging_ast`.

    The directory of the file is registered through `save_file_path_to_sys` first.

    Args:
        file_path (str): Path of the source file to be scanned.
        belonging_ast (ast.AST): Ast which the imports belong to. Default: ``None``.

    Raises:
        RuntimeError: If `file_path` does not exist.
    """
    self.save_file_path_to_sys(0, file_path, belonging_ast)
    if not os.path.exists(file_path):
        raise RuntimeError(f"For MindSpore Rewrite, in module parser, file {file_path} not exist.")
    with open(file_path, "r", encoding="utf-8") as source_file:
        file_content = source_file.read()
    found_imports = AstImportFinder(ast.parse(dedent(file_content))).get_import_node()
    if not found_imports:
        return
    # add imports to import_asts of belonging_ast
    target_imports = self._get_imports_list_of_ast(belonging_ast)
    for raw_import in found_imports:
        resolved_import = SymbolTree._process_relative_import(raw_import, file_path)
        # falsy results of the relative import processing are not saved
        if resolved_import:
            target_imports.append(resolved_import)
1594
+
1595
def add_import(self, module: types.ModuleType, name: str, belonging_ast: ast.AST = None):
    """
    Add codes: from `module` import `name`.

    Args:
        module (types.ModuleType): The module `name` is imported from.
        name (str): Name of the attribute to be imported.
        belonging_ast (ast.AST): Ast which the import belongs to, used to select the import
            list. Default: ``None``.

    Raises:
        TypeError: If `module` is not a module.
    """
    if not isinstance(module, types.ModuleType):
        raise TypeError(f"For add_import, module should be ModuleType, but got {type(module)}")
    if not hasattr(module, name):
        logger.info(f"module {module.__name__} doesn't have attr '{name}', it may be a local variable.")
        return
    # add imports to import_asts of belonging_ast
    import_asts = self._get_imports_list_of_ast(belonging_ast)
    if module.__name__ == "__main__":
        # get attr from module instead of import to avoid duplicate execution of __main__ module
        code = f"{name} = getattr(sys.modules['__main__'], '{name}')"
        import_asts.append(ast.parse(code).body[0])
        return
    if module.__name__ == "builtins":
        # built-in functions are not need to be imported
        return
    try:
        func_file_path = inspect.getabsfile(module)
    except TypeError:
        # Built-in modules (e.g. sys) have no file, so they can only be imported by name.
        import_asts.append(ast.ImportFrom(module=module.__name__,
                                          names=[ast.alias(name=name, asname=None)], level=0))
        return
    func_file_path = os.path.normcase(func_file_path)
    # Candidate roots are the entries of sys.path that prefix the file path, longest first.
    prefix_paths = [path for path in map(os.path.normcase, sys.path) if func_file_path.startswith(path)]
    prefix_paths.sort(key=len, reverse=True)
    for path in prefix_paths:
        import_path = func_file_path[len(path):]
        import_str = import_path.replace(os.path.sep, '.')
        import_str = import_str[1:] # remove first '.'
        mod = import_str.rsplit('.', 1)[0]
        if SymbolTree._check_import(func_file_path[:len(path)], mod):
            import_asts.append(ast.ImportFrom(module=mod, names=[ast.alias(name=name, asname=None)], level=0))
            break
    else:
        # No root of sys.path works: register the directory of the file and import by file name.
        self.save_file_path_to_sys(0, func_file_path, belonging_ast)
        mod = os.path.basename(func_file_path).rsplit('.')[0]
        import_asts.append(ast.ImportFrom(module=mod, names=[ast.alias(name=name, asname=None)], level=0))
1636
+
1637
def _get_imports_list_of_ast(self, belonging_ast: ast.AST):
    """
    Get the import list which imports of `belonging_ast` are saved to.

    Args:
        belonging_ast (ast.AST): A father class ast or an external ast, or None.

    Returns:
        The import list registered for `belonging_ast` in `self._father_class_ast` or
        `self._external_ast`; `self._import_asts` when `belonging_ast` is None or unknown.
    """
    if belonging_ast is None:
        return self._import_asts
    if belonging_ast in self._father_class_ast:
        return self._father_class_ast.get(belonging_ast)
    if belonging_ast in self._external_ast:
        return self._external_ast.get(belonging_ast)
    return self._import_asts
1646
+
1647
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
    """
    Resolve `node_or_name` to a node.

    Args:
        node_or_name (Union[Node, str]): A node, or the name of a node.

    Returns:
        The result of `self.get_node` when a str is given, else `node_or_name` itself.
    """
    return self.get_node(node_or_name) if isinstance(node_or_name, str) else node_or_name
1651
+
1652
def _handle_custom_obj_in_normalized_args(self, node: Node):
    """
    Convert CustomObjValue type argument to NamingValue type argument by storing custom object to obj.

    For each CustomObjValue argument, the object is set as an attribute of the origin network,
    a statement `self.<name> = obj.<name>` is appended to the init function ast, and the
    argument is replaced by the naming value `self.<name>`.

    Args:
        node (Node): A Node whose arguments and keyword arguments to be handled.

    Raises:
        TypeError: If a normalized argument of `node` is not a ScopedValue.
    """
    converted_args = {}
    for arg_key, arg_value in node.get_normalized_args().items():
        if not isinstance(arg_value, ScopedValue):
            raise TypeError("value should be ScopedValue, got: ", type(arg_value))
        if arg_value.type != ValueType.CustomObjValue:
            # other kinds of values are kept untouched
            converted_args[arg_key] = arg_value
            continue
        # Save CustomObjValue into _origin_network(i.e. obj): obj.attr_name = CustomObjValue
        attr_name = self.unique_name(f"arg_{type(arg_value.value).__name__}")
        setattr(self._origin_network, attr_name, arg_value.value)
        # Add new code to __init__(): self.attr_name = obj.attr_name
        init_stmt = ast.parse(f"self.{attr_name} = obj.{attr_name}").body[0]
        self._init_func_ast.body.append(init_stmt)
        # Modify node's normalized_args: CustomObjValue -> self.attr_name
        converted_args[arg_key] = ScopedValue.create_naming_value(attr_name, "self")
    node.set_normalized_args(converted_args)
1675
+
1676
def _get_cls_through_file(self):
    """
    Load rewritten network class of current SymbolTree.
    1. Get source code of current SymbolTree.
    2. Save source code to a file under `<cwd>/rewritten_network`.
    3. Import rewritten network class from that file.

    Returns:
        A class handle.

    Raises:
        ImportError: If the generated module cannot be loaded.
        RuntimeError: If the generated module does not contain the network class.
    """
    # Generate the code first, so that a failure leaves no empty file behind.
    source = self.get_code()
    file_path = os.path.join(os.getcwd(), "rewritten_network")
    os.makedirs(file_path, mode=0o700, exist_ok=True)
    file_name = f"{self._opt_cls_name}_{id(self)}.py"
    network_file = os.path.join(file_path, file_name)
    # O_TRUNC is required: the file name can be reused (same class name and id), and without
    # truncation a longer previous file would leave stale code behind the new source.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    with os.fdopen(os.open(network_file, flags, stat.S_IRWXU), 'wb') as f:
        f.write(source.encode('utf-8'))
        f.flush()
        os.fsync(f)
    tmp_module_path, tmp_module_file = os.path.split(network_file)
    tmp_module_name = tmp_module_file[:-3]
    # avoid piling up the same entry on every call
    if tmp_module_path not in sys.path:
        sys.path.append(tmp_module_path)

    tmp_module = None
    retry = 0
    while tmp_module is None:
        spec = importlib.util.spec_from_file_location(tmp_module_name, network_file)
        if spec:
            tmp_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(tmp_module)
        else:
            logger.warning(f"load module {tmp_module_name} failed, retrying.")
            if retry > 10:
                break
            time.sleep(0.5)
            retry += 1
    if tmp_module is None:
        raise ImportError(f"load module {tmp_module_name} failed.")
    # Save new module to sys.modules to support inspect.getsource().
    sys.modules[tmp_module_name] = tmp_module
    # default None makes the check below reachable instead of failing with AttributeError
    network_cls = getattr(tmp_module, self._opt_cls_name, None)
    if network_cls is None:
        raise RuntimeError("Can not find network class:", self._opt_cls_name)
    return network_cls
1725
+
1726
def _on_change(self, event: Event):
    """Mark this symbol tree as modified, then pass `event` on to `self.changed`."""
    self._modified = True
    self.changed(event)
1729
+
1730
def _cal_difference_set(self, input, other):
    """
    Calculate difference set of two iterables.

    Args:
        input (Iterable): Elements to start from.
        other (Iterable): Elements to be removed.

    Returns:
        A set of the elements of `input` which are not in `other`.
    """
    return set(input).difference(set(other))
1735
+
1736
def _merge_origin_property(self, new_net):
    """
    Merge property of origin network into `new_net`.

    Attributes, cells and primitives which exist in the origin network but not in `new_net` are
    copied over, except attributes recorded in `self._deleted_field` and cells recorded in
    `self._deleted_node`.

    Args:
        new_net: The network object to be completed, modified in place.
    """
    origin_net = self._origin_network
    # merge attributes
    missing_attrs = self._cal_difference_set(dir(origin_net), dir(new_net))
    for attr_name in self._cal_difference_set(missing_attrs, self._deleted_field.keys()):
        setattr(new_net, attr_name, getattr(origin_net, attr_name))
    # merge cells
    missing_cells = self._cal_difference_set(origin_net.name_cells().keys(), new_net.name_cells().keys())
    for cell_name in self._cal_difference_set(missing_cells, self._deleted_node):
        new_net.insert_child_to_cell(cell_name, origin_net.name_cells()[cell_name])
    # merge primitives
    # pylint: disable=protected-access
    origin_primitives = origin_net._primitives
    for primitive_name in self._cal_difference_set(origin_primitives.keys(), new_net._primitives.keys()):
        new_net._primitives[primitive_name] = origin_net._primitives[primitive_name]
1752
+
1753
+ def _process_duplicate_name_modules(self, module_ast: ast.Module):
1754
+ """Adjust names of imported modules with same name and different import path."""
1755
+ # {name1: [path1, path2, ...], ...}
1756
+ name_path_dict: Dict[str, List[str]] = {}
1757
+ # names of modules need to be suffixed: {name1: suffixed_name1, ...}
1758
+ name_need_suffix: Dict[str, str] = {}
1759
+ # used to record replace actions in ast.ImportFrom
1760
+ import_replacer = AstReplacer(None)
1761
+ self._tmp_replacers.append(import_replacer)
1762
+
1763
+ def suffix_alias(alias: ast.alias, suffix: int):
1764
+ """suffix the name of alias in ast.ImportFrom"""
1765
+ new_name = f"{alias.asname}_{suffix}" if alias.asname else f"{alias.name}_{suffix}"
1766
+ import_replacer._trace.append((alias, 'asname', alias.asname, new_name)) # pylint: disable=protected-access
1767
+ alias.asname = new_name
1768
+ return new_name
1769
+
1770
+ def is_divider(ast_node):
1771
+ """judge if ast node is divider of new class or function by checking ast.Expr of '#'."""
1772
+ return isinstance(ast_node, ast.Expr) and isinstance(ast_node.value, ast.Name) and ast_node.value.id == '#'
1773
+
1774
+ def record_imports(ast_node: ast.ImportFrom):
1775
+ """record name and path of imported modules to find the duplicate name modules."""
1776
+ for alias in ast_node.names[:]:
1777
+ name = alias.asname if alias.asname else alias.name
1778
+ if name == '*':
1779
+ continue
1780
+ # current name is firstly imported, just record it
1781
+ if name not in name_path_dict:
1782
+ name_path_dict[name] = [ast_node.module]
1783
+ continue
1784
+ # current name is imported before, check whether it is a duplicated name
1785
+ for idx, path in enumerate(name_path_dict[name]):
1786
+ if path.startswith(ast_node.module):
1787
+ # e.g. origin code is 'from a.b.c import A' and new code is 'from a.b import A'
1788
+ # then we update name_path_dict[name][idx] from 'a.b.c' to 'a.b' and update name to A_{idx}
1789
+ name_path_dict[name][idx] = ast_node.module
1790
+ if idx > 0:
1791
+ name_need_suffix[name] = suffix_alias(alias, idx)
1792
+ break
1793
+ elif ast_node.module.startswith(path):
1794
+ # e.g. origin code is 'from a.b import A' and new code is 'from a.b.c import A'
1795
+ # then we just need to update name to A_{idx}
1796
+ if idx > 0:
1797
+ name_need_suffix[name] = suffix_alias(alias, idx)
1798
+ break
1799
+ else:
1800
+ # current name is imported from a new path, save the path and update the name
1801
+ name_path_dict[name].append(ast_node.module)
1802
+ name_need_suffix[name] = suffix_alias(alias, len(name_path_dict[name]) - 1)
1803
+
1804
+ def suffix_names_in_ast(ast_node: Union[ast.ClassDef, ast.FunctionDef]):
1805
+ """suffix names in ast.ClassDef or ast.FunctionDef"""
1806
+ if not name_need_suffix:
1807
+ return
1808
+ name_replacer = AstReplacer(ast_node)
1809
+ self._tmp_replacers.append(name_replacer)
1810
+ for name, new_name in name_need_suffix.items():
1811
+ name_replacer.replace_all(name, new_name)
1812
+
1813
+ for ast_node in module_ast.body:
1814
+ if isinstance(ast_node, ast.ImportFrom):
1815
+ record_imports(ast_node)
1816
+ if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef)):
1817
+ suffix_names_in_ast(ast_node)
1818
+ if is_divider(ast_node):
1819
+ name_need_suffix.clear()