mindspore 2.4.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1406)
  1. mindspore/.commit_id +1 -0
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +53 -0
  18. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +106 -0
  22. mindspore/_checkparam.py +1419 -0
  23. mindspore/_extends/__init__.py +23 -0
  24. mindspore/_extends/builtin_operations.py +224 -0
  25. mindspore/_extends/graph_kernel/__init__.py +17 -0
  26. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  27. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  28. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  29. mindspore/_extends/graph_kernel/model/model.py +553 -0
  30. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  31. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  32. mindspore/_extends/graph_kernel/splitter.py +140 -0
  33. mindspore/_extends/graph_kernel/utils.py +28 -0
  34. mindspore/_extends/parallel_compile/__init__.py +19 -0
  35. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  38. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  39. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  40. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  41. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  42. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  43. mindspore/_extends/parse/__init__.py +49 -0
  44. mindspore/_extends/parse/compile_config.py +299 -0
  45. mindspore/_extends/parse/namespace.py +136 -0
  46. mindspore/_extends/parse/parser.py +1448 -0
  47. mindspore/_extends/parse/resources.py +213 -0
  48. mindspore/_extends/parse/standard_method.py +4475 -0
  49. mindspore/_extends/parse/trope.py +97 -0
  50. mindspore/_extends/pijit/__init__.py +23 -0
  51. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  52. mindspore/_extends/remote/__init__.py +19 -0
  53. mindspore/_extends/remote/kernel_build_server.py +199 -0
  54. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  57. mindspore/_extends/utils.py +68 -0
  58. mindspore/_install_custom.py +43 -0
  59. mindspore/_profiler.py +30 -0
  60. mindspore/amp.py +433 -0
  61. mindspore/atlprov.dll +0 -0
  62. mindspore/avcodec-59.dll +0 -0
  63. mindspore/avdevice-59.dll +0 -0
  64. mindspore/avfilter-8.dll +0 -0
  65. mindspore/avformat-59.dll +0 -0
  66. mindspore/avutil-57.dll +0 -0
  67. mindspore/boost/__init__.py +42 -0
  68. mindspore/boost/adasum.py +319 -0
  69. mindspore/boost/base.py +535 -0
  70. mindspore/boost/boost.py +400 -0
  71. mindspore/boost/boost_cell_wrapper.py +790 -0
  72. mindspore/boost/dim_reduce.py +323 -0
  73. mindspore/boost/grad_accumulation.py +79 -0
  74. mindspore/boost/grad_freeze.py +382 -0
  75. mindspore/boost/group_loss_scale_manager.py +166 -0
  76. mindspore/boost/less_batch_normalization.py +174 -0
  77. mindspore/c1.dll +0 -0
  78. mindspore/c1xx.dll +0 -0
  79. mindspore/c2.dll +0 -0
  80. mindspore/cfgpersist.dll +0 -0
  81. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  82. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  83. mindspore/common/__init__.py +86 -0
  84. mindspore/common/_auto_dynamic.py +68 -0
  85. mindspore/common/_decorator.py +50 -0
  86. mindspore/common/_jit_fallback_utils.py +110 -0
  87. mindspore/common/_monad.py +25 -0
  88. mindspore/common/_pijit_context.py +190 -0
  89. mindspore/common/_register_for_adapter.py +74 -0
  90. mindspore/common/_register_for_recompute.py +48 -0
  91. mindspore/common/_register_for_tensor.py +46 -0
  92. mindspore/common/_stub_tensor.py +210 -0
  93. mindspore/common/_tensor_overload.py +139 -0
  94. mindspore/common/_utils.py +122 -0
  95. mindspore/common/api.py +2064 -0
  96. mindspore/common/auto_dynamic_shape.py +507 -0
  97. mindspore/common/dtype.py +422 -0
  98. mindspore/common/dump.py +130 -0
  99. mindspore/common/file_system.py +48 -0
  100. mindspore/common/generator.py +254 -0
  101. mindspore/common/hook_handle.py +143 -0
  102. mindspore/common/initializer.py +880 -0
  103. mindspore/common/jit_config.py +98 -0
  104. mindspore/common/lazy_inline.py +240 -0
  105. mindspore/common/mindir_util.py +111 -0
  106. mindspore/common/mutable.py +234 -0
  107. mindspore/common/no_inline.py +54 -0
  108. mindspore/common/np_dtype.py +25 -0
  109. mindspore/common/parameter.py +1081 -0
  110. mindspore/common/recompute.py +292 -0
  111. mindspore/common/seed.py +260 -0
  112. mindspore/common/sparse_tensor.py +1175 -0
  113. mindspore/common/symbol.py +122 -0
  114. mindspore/common/tensor.py +5039 -0
  115. mindspore/communication/__init__.py +37 -0
  116. mindspore/communication/_comm_helper.py +501 -0
  117. mindspore/communication/_hccl_management.py +297 -0
  118. mindspore/communication/comm_func.py +1395 -0
  119. mindspore/communication/management.py +673 -0
  120. mindspore/config/op_info.config +533 -0
  121. mindspore/context.py +2077 -0
  122. mindspore/d3dcompiler_47.dll +0 -0
  123. mindspore/dataset/__init__.py +90 -0
  124. mindspore/dataset/audio/__init__.py +61 -0
  125. mindspore/dataset/audio/transforms.py +3690 -0
  126. mindspore/dataset/audio/utils.py +386 -0
  127. mindspore/dataset/audio/validators.py +1172 -0
  128. mindspore/dataset/callback/__init__.py +20 -0
  129. mindspore/dataset/callback/ds_callback.py +368 -0
  130. mindspore/dataset/callback/validators.py +32 -0
  131. mindspore/dataset/core/__init__.py +13 -0
  132. mindspore/dataset/core/config.py +1095 -0
  133. mindspore/dataset/core/datatypes.py +101 -0
  134. mindspore/dataset/core/py_util_helpers.py +65 -0
  135. mindspore/dataset/core/validator_helpers.py +781 -0
  136. mindspore/dataset/debug/__init__.py +21 -0
  137. mindspore/dataset/debug/debug_hook.py +97 -0
  138. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  139. mindspore/dataset/engine/__init__.py +124 -0
  140. mindspore/dataset/engine/cache_admin.py +47 -0
  141. mindspore/dataset/engine/cache_client.py +129 -0
  142. mindspore/dataset/engine/datasets.py +4582 -0
  143. mindspore/dataset/engine/datasets_audio.py +911 -0
  144. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  145. mindspore/dataset/engine/datasets_text.py +2161 -0
  146. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  147. mindspore/dataset/engine/datasets_vision.py +4816 -0
  148. mindspore/dataset/engine/iterators.py +371 -0
  149. mindspore/dataset/engine/obs/__init__.py +23 -0
  150. mindspore/dataset/engine/obs/config_loader.py +68 -0
  151. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  152. mindspore/dataset/engine/obs/util.py +482 -0
  153. mindspore/dataset/engine/offload.py +596 -0
  154. mindspore/dataset/engine/queue.py +304 -0
  155. mindspore/dataset/engine/samplers.py +895 -0
  156. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  157. mindspore/dataset/engine/validators.py +2895 -0
  158. mindspore/dataset/text/__init__.py +51 -0
  159. mindspore/dataset/text/transforms.py +1703 -0
  160. mindspore/dataset/text/utils.py +715 -0
  161. mindspore/dataset/text/validators.py +642 -0
  162. mindspore/dataset/transforms/__init__.py +45 -0
  163. mindspore/dataset/transforms/c_transforms.py +638 -0
  164. mindspore/dataset/transforms/py_transforms.py +393 -0
  165. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  166. mindspore/dataset/transforms/transforms.py +1260 -0
  167. mindspore/dataset/transforms/validators.py +410 -0
  168. mindspore/dataset/utils/__init__.py +19 -0
  169. mindspore/dataset/utils/browse_dataset.py +190 -0
  170. mindspore/dataset/utils/line_reader.py +126 -0
  171. mindspore/dataset/vision/__init__.py +65 -0
  172. mindspore/dataset/vision/c_transforms.py +2641 -0
  173. mindspore/dataset/vision/py_transforms.py +2120 -0
  174. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  175. mindspore/dataset/vision/transforms.py +7295 -0
  176. mindspore/dataset/vision/utils.py +863 -0
  177. mindspore/dataset/vision/validators.py +1483 -0
  178. mindspore/default_config.py +2 -0
  179. mindspore/dnnl.dll +0 -0
  180. mindspore/dpcmi.dll +0 -0
  181. mindspore/experimental/__init__.py +20 -0
  182. mindspore/experimental/es/__init__.py +22 -0
  183. mindspore/experimental/es/embedding_service.py +883 -0
  184. mindspore/experimental/es/embedding_service_layer.py +581 -0
  185. mindspore/experimental/llm_boost/__init__.py +21 -0
  186. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  187. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  188. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  189. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  190. mindspore/experimental/llm_boost/register.py +129 -0
  191. mindspore/experimental/llm_boost/utils.py +31 -0
  192. mindspore/experimental/map_parameter.py +309 -0
  193. mindspore/experimental/optim/__init__.py +40 -0
  194. mindspore/experimental/optim/adadelta.py +161 -0
  195. mindspore/experimental/optim/adagrad.py +168 -0
  196. mindspore/experimental/optim/adam.py +193 -0
  197. mindspore/experimental/optim/adamax.py +170 -0
  198. mindspore/experimental/optim/adamw.py +290 -0
  199. mindspore/experimental/optim/asgd.py +153 -0
  200. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  201. mindspore/experimental/optim/nadam.py +157 -0
  202. mindspore/experimental/optim/optimizer.py +262 -0
  203. mindspore/experimental/optim/radam.py +194 -0
  204. mindspore/experimental/optim/rmsprop.py +154 -0
  205. mindspore/experimental/optim/rprop.py +164 -0
  206. mindspore/experimental/optim/sgd.py +156 -0
  207. mindspore/hal/__init__.py +40 -0
  208. mindspore/hal/_ascend.py +57 -0
  209. mindspore/hal/_base.py +57 -0
  210. mindspore/hal/_cpu.py +56 -0
  211. mindspore/hal/_gpu.py +57 -0
  212. mindspore/hal/contiguous_tensors_handle.py +175 -0
  213. mindspore/hal/device.py +356 -0
  214. mindspore/hal/event.py +179 -0
  215. mindspore/hal/memory.py +326 -0
  216. mindspore/hal/stream.py +357 -0
  217. mindspore/include/OWNERS +7 -0
  218. mindspore/include/api/allocator.h +97 -0
  219. mindspore/include/api/callback/callback.h +93 -0
  220. mindspore/include/api/callback/ckpt_saver.h +41 -0
  221. mindspore/include/api/callback/loss_monitor.h +33 -0
  222. mindspore/include/api/callback/lr_scheduler.h +51 -0
  223. mindspore/include/api/callback/time_monitor.h +34 -0
  224. mindspore/include/api/callback/train_accuracy.h +37 -0
  225. mindspore/include/api/cell.h +90 -0
  226. mindspore/include/api/cfg.h +82 -0
  227. mindspore/include/api/context.h +602 -0
  228. mindspore/include/api/data_type.h +47 -0
  229. mindspore/include/api/delegate.h +178 -0
  230. mindspore/include/api/delegate_api.h +75 -0
  231. mindspore/include/api/dual_abi_helper.h +208 -0
  232. mindspore/include/api/format.h +28 -0
  233. mindspore/include/api/graph.h +46 -0
  234. mindspore/include/api/kernel.h +58 -0
  235. mindspore/include/api/kernel_api.h +168 -0
  236. mindspore/include/api/metrics/accuracy.h +36 -0
  237. mindspore/include/api/metrics/metrics.h +41 -0
  238. mindspore/include/api/model.h +438 -0
  239. mindspore/include/api/model_group.h +91 -0
  240. mindspore/include/api/model_parallel_runner.h +168 -0
  241. mindspore/include/api/serialization.h +185 -0
  242. mindspore/include/api/status.h +192 -0
  243. mindspore/include/api/types.h +431 -0
  244. mindspore/include/api/visible.h +41 -0
  245. mindspore/include/c_api/context_c.h +179 -0
  246. mindspore/include/c_api/data_type_c.h +52 -0
  247. mindspore/include/c_api/format_c.h +46 -0
  248. mindspore/include/c_api/model_c.h +347 -0
  249. mindspore/include/c_api/status_c.h +79 -0
  250. mindspore/include/c_api/tensor_c.h +146 -0
  251. mindspore/include/c_api/types_c.h +67 -0
  252. mindspore/include/dataset/config.h +163 -0
  253. mindspore/include/dataset/constants.h +363 -0
  254. mindspore/include/dataset/execute.h +196 -0
  255. mindspore/include/dataset/text.h +1092 -0
  256. mindspore/include/dataset/transforms.h +638 -0
  257. mindspore/include/dataset/vision.h +2129 -0
  258. mindspore/include/dataset/vision_ascend.h +206 -0
  259. mindspore/include/dataset/vision_lite.h +625 -0
  260. mindspore/jpeg62.dll +0 -0
  261. mindspore/log.py +633 -0
  262. mindspore/mindrecord/__init__.py +43 -0
  263. mindspore/mindrecord/common/__init__.py +17 -0
  264. mindspore/mindrecord/common/constant.py +20 -0
  265. mindspore/mindrecord/common/enums.py +44 -0
  266. mindspore/mindrecord/common/exceptions.py +311 -0
  267. mindspore/mindrecord/config.py +809 -0
  268. mindspore/mindrecord/filereader.py +174 -0
  269. mindspore/mindrecord/filewriter.py +722 -0
  270. mindspore/mindrecord/mindpage.py +210 -0
  271. mindspore/mindrecord/shardheader.py +141 -0
  272. mindspore/mindrecord/shardindexgenerator.py +74 -0
  273. mindspore/mindrecord/shardreader.py +117 -0
  274. mindspore/mindrecord/shardsegment.py +128 -0
  275. mindspore/mindrecord/shardutils.py +185 -0
  276. mindspore/mindrecord/shardwriter.py +237 -0
  277. mindspore/mindrecord/tools/__init__.py +17 -0
  278. mindspore/mindrecord/tools/cifar10.py +140 -0
  279. mindspore/mindrecord/tools/cifar100.py +153 -0
  280. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  281. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  282. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  283. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  284. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  285. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  286. mindspore/mindspore_backend.dll +0 -0
  287. mindspore/mindspore_common.dll +0 -0
  288. mindspore/mindspore_core.dll +0 -0
  289. mindspore/mindspore_glog.dll +0 -0
  290. mindspore/mindspore_np_dtype.dll +0 -0
  291. mindspore/mindspore_ops.dll +0 -0
  292. mindspore/mint/__init__.py +1586 -0
  293. mindspore/mint/distributed/__init__.py +31 -0
  294. mindspore/mint/distributed/distributed.py +254 -0
  295. mindspore/mint/linalg/__init__.py +22 -0
  296. mindspore/mint/nn/__init__.py +757 -0
  297. mindspore/mint/nn/functional.py +679 -0
  298. mindspore/mint/nn/layer/__init__.py +39 -0
  299. mindspore/mint/nn/layer/activation.py +133 -0
  300. mindspore/mint/nn/layer/normalization.py +477 -0
  301. mindspore/mint/nn/layer/pooling.py +110 -0
  302. mindspore/mint/optim/__init__.py +24 -0
  303. mindspore/mint/optim/adamw.py +206 -0
  304. mindspore/mint/special/__init__.py +63 -0
  305. mindspore/msobj140.dll +0 -0
  306. mindspore/mspdb140.dll +0 -0
  307. mindspore/mspdbcore.dll +0 -0
  308. mindspore/mspdbst.dll +0 -0
  309. mindspore/mspft140.dll +0 -0
  310. mindspore/msvcdis140.dll +0 -0
  311. mindspore/msvcp140.dll +0 -0
  312. mindspore/msvcp140_1.dll +0 -0
  313. mindspore/msvcp140_2.dll +0 -0
  314. mindspore/msvcp140_atomic_wait.dll +0 -0
  315. mindspore/msvcp140_codecvt_ids.dll +0 -0
  316. mindspore/multiprocessing/__init__.py +73 -0
  317. mindspore/nn/__init__.py +47 -0
  318. mindspore/nn/cell.py +2787 -0
  319. mindspore/nn/dynamic_lr.py +482 -0
  320. mindspore/nn/grad/__init__.py +21 -0
  321. mindspore/nn/grad/cell_grad.py +196 -0
  322. mindspore/nn/layer/__init__.py +63 -0
  323. mindspore/nn/layer/activation.py +1822 -0
  324. mindspore/nn/layer/basic.py +1629 -0
  325. mindspore/nn/layer/channel_shuffle.py +90 -0
  326. mindspore/nn/layer/combined.py +248 -0
  327. mindspore/nn/layer/container.py +734 -0
  328. mindspore/nn/layer/conv.py +1505 -0
  329. mindspore/nn/layer/dense.py +204 -0
  330. mindspore/nn/layer/embedding.py +869 -0
  331. mindspore/nn/layer/image.py +661 -0
  332. mindspore/nn/layer/math.py +1069 -0
  333. mindspore/nn/layer/normalization.py +1273 -0
  334. mindspore/nn/layer/padding.py +880 -0
  335. mindspore/nn/layer/pooling.py +2302 -0
  336. mindspore/nn/layer/rnn_cells.py +388 -0
  337. mindspore/nn/layer/rnns.py +849 -0
  338. mindspore/nn/layer/thor_layer.py +963 -0
  339. mindspore/nn/layer/timedistributed.py +155 -0
  340. mindspore/nn/layer/transformer.py +823 -0
  341. mindspore/nn/learning_rate_schedule.py +512 -0
  342. mindspore/nn/loss/__init__.py +36 -0
  343. mindspore/nn/loss/loss.py +2924 -0
  344. mindspore/nn/metrics.py +53 -0
  345. mindspore/nn/optim/__init__.py +45 -0
  346. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  347. mindspore/nn/optim/ada_grad.py +217 -0
  348. mindspore/nn/optim/adadelta.py +206 -0
  349. mindspore/nn/optim/adafactor.py +448 -0
  350. mindspore/nn/optim/adam.py +1297 -0
  351. mindspore/nn/optim/adamax.py +220 -0
  352. mindspore/nn/optim/adasum.py +548 -0
  353. mindspore/nn/optim/asgd.py +216 -0
  354. mindspore/nn/optim/ftrl.py +401 -0
  355. mindspore/nn/optim/lamb.py +296 -0
  356. mindspore/nn/optim/lars.py +202 -0
  357. mindspore/nn/optim/lazyadam.py +533 -0
  358. mindspore/nn/optim/momentum.py +239 -0
  359. mindspore/nn/optim/optimizer.py +1034 -0
  360. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  361. mindspore/nn/optim/rmsprop.py +264 -0
  362. mindspore/nn/optim/rprop.py +251 -0
  363. mindspore/nn/optim/sgd.py +237 -0
  364. mindspore/nn/optim/tft_wrapper.py +127 -0
  365. mindspore/nn/optim/thor.py +1310 -0
  366. mindspore/nn/probability/__init__.py +22 -0
  367. mindspore/nn/probability/bijector/__init__.py +35 -0
  368. mindspore/nn/probability/bijector/bijector.py +337 -0
  369. mindspore/nn/probability/bijector/exp.py +65 -0
  370. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  371. mindspore/nn/probability/bijector/invert.py +126 -0
  372. mindspore/nn/probability/bijector/power_transform.py +196 -0
  373. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  374. mindspore/nn/probability/bijector/softplus.py +189 -0
  375. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  376. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  377. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  378. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  379. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  380. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  381. mindspore/nn/probability/distribution/__init__.py +56 -0
  382. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  383. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  384. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  385. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  386. mindspore/nn/probability/distribution/beta.py +391 -0
  387. mindspore/nn/probability/distribution/categorical.py +435 -0
  388. mindspore/nn/probability/distribution/cauchy.py +383 -0
  389. mindspore/nn/probability/distribution/distribution.py +827 -0
  390. mindspore/nn/probability/distribution/exponential.py +350 -0
  391. mindspore/nn/probability/distribution/gamma.py +391 -0
  392. mindspore/nn/probability/distribution/geometric.py +335 -0
  393. mindspore/nn/probability/distribution/gumbel.py +257 -0
  394. mindspore/nn/probability/distribution/half_normal.py +133 -0
  395. mindspore/nn/probability/distribution/laplace.py +128 -0
  396. mindspore/nn/probability/distribution/log_normal.py +272 -0
  397. mindspore/nn/probability/distribution/logistic.py +379 -0
  398. mindspore/nn/probability/distribution/normal.py +336 -0
  399. mindspore/nn/probability/distribution/poisson.py +288 -0
  400. mindspore/nn/probability/distribution/student_t.py +149 -0
  401. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  402. mindspore/nn/probability/distribution/uniform.py +375 -0
  403. mindspore/nn/reinforcement/__init__.py +24 -0
  404. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  405. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  406. mindspore/nn/reinforcement/tensor_array.py +145 -0
  407. mindspore/nn/sparse/__init__.py +23 -0
  408. mindspore/nn/sparse/sparse.py +147 -0
  409. mindspore/nn/wrap/__init__.py +49 -0
  410. mindspore/nn/wrap/cell_wrapper.py +968 -0
  411. mindspore/nn/wrap/grad_reducer.py +608 -0
  412. mindspore/nn/wrap/loss_scale.py +694 -0
  413. mindspore/numpy/__init__.py +121 -0
  414. mindspore/numpy/array_creations.py +2731 -0
  415. mindspore/numpy/array_ops.py +2629 -0
  416. mindspore/numpy/dtypes.py +185 -0
  417. mindspore/numpy/fft.py +966 -0
  418. mindspore/numpy/logic_ops.py +936 -0
  419. mindspore/numpy/math_ops.py +5911 -0
  420. mindspore/numpy/utils.py +214 -0
  421. mindspore/numpy/utils_const.py +565 -0
  422. mindspore/opencv_core452.dll +0 -0
  423. mindspore/opencv_imgcodecs452.dll +0 -0
  424. mindspore/opencv_imgproc452.dll +0 -0
  425. mindspore/ops/__init__.py +56 -0
  426. mindspore/ops/_constants.py +30 -0
  427. mindspore/ops/_grad_experimental/__init__.py +31 -0
  428. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  429. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  430. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  431. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  432. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  433. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  434. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  435. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  436. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  437. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  438. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  439. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  440. mindspore/ops/_op_impl/__init__.py +23 -0
  441. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  442. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  443. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  444. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  445. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  446. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  447. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  448. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  449. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  450. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  451. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  452. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  453. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  454. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  455. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  456. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  457. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  458. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  459. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  460. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  461. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  462. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  463. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  464. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  465. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  466. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  467. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  468. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  469. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  470. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  471. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  472. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  473. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  474. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  475. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  476. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  477. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  478. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  479. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  480. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  481. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  482. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  483. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  484. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  485. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  486. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  487. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  488. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  490. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  491. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  493. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  494. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  495. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  496. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  497. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  498. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  499. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  500. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  501. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  502. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  503. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  504. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  505. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  506. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  507. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  508. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  509. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  510. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  511. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  513. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  514. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  515. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  516. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  517. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  518. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  519. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  520. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  521. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  522. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  523. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  524. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  526. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  527. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  528. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  529. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  530. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  531. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  532. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  533. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  534. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  535. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  537. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  538. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  539. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  540. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  541. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  542. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  543. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  544. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  545. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  546. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  547. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  548. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  549. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  550. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  551. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  552. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  553. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  554. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  555. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  556. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  557. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  558. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  559. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  560. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  561. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  562. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  563. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  564. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  565. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  566. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  567. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  568. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  569. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  570. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  571. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  572. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  574. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  575. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  576. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  577. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  578. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  579. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  580. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  581. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  582. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  583. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  584. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  585. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  586. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  587. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  588. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  589. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  590. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  591. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  592. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  593. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  594. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  595. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  596. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  597. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  598. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  599. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  600. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  601. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  603. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  604. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  605. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  606. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  608. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  609. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  610. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  611. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  612. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  613. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  614. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  615. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  616. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  617. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  618. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  619. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  620. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  621. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  622. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  623. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  624. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  625. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  626. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  627. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  628. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  629. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  630. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  632. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  633. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  634. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  635. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  636. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  637. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  638. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  639. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  640. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  641. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  642. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  643. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  644. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  645. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  646. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  647. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  648. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  650. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  651. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  652. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  653. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  654. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  655. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  656. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  657. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  658. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  659. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  660. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  661. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  662. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  663. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  664. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  665. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  666. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  667. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  668. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  669. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  670. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  673. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  675. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  676. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  677. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  679. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  680. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  681. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  682. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  683. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  684. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  685. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  686. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  687. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  688. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  689. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  690. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  691. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  692. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  693. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  694. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  695. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  696. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  697. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  698. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  699. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  700. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  701. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  702. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  703. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  704. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  705. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  706. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  707. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  708. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  709. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  710. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  711. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  712. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  713. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  714. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  715. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  716. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  717. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  718. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  719. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  720. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  721. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  722. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  723. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  724. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  725. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  726. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  727. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  728. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  729. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  730. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  731. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  732. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  733. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  734. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  735. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  736. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  737. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  738. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  739. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  740. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  741. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  742. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  743. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  744. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  745. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  746. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  747. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  748. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  749. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  750. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  751. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  752. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  753. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  754. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  755. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  756. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  757. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  758. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  759. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  760. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  761. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  762. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  763. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  764. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  765. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  766. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  767. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  768. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  769. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  770. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  772. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  773. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  774. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  775. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  776. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  777. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  778. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  779. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  780. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  781. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  782. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  783. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  784. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  785. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  786. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  787. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  788. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  789. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  790. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  791. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  792. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  793. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  794. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  795. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  796. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  797. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  798. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  799. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  800. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  801. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  802. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  803. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  804. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  805. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  806. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  808. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  809. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  810. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  811. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  812. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  813. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  814. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  815. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  816. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  817. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  818. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  819. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  820. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  821. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  822. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  823. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  825. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  826. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  827. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  828. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  829. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  841. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  843. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  844. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  845. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  846. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  847. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  848. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  849. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  850. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  851. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  852. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  853. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  854. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  855. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  856. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  857. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  858. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  859. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  860. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  861. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  862. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  863. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  864. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  865. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  866. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  868. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  869. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  870. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  871. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  872. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  873. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  874. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  876. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  877. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  878. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  879. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  880. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  881. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  882. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  883. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  884. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  885. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  886. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  887. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  888. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  889. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  890. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  891. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  892. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  893. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  894. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  895. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  896. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  897. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  898. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  899. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  900. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  901. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  902. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  903. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  904. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  905. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  906. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  907. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  908. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  909. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  910. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  911. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  912. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  913. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  914. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  915. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  916. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  917. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  918. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  919. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  920. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  921. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  922. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  923. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  924. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  925. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  926. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  927. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  929. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  930. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  931. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  932. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  933. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  934. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  935. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  936. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  937. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  938. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  939. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  940. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  941. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  942. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  943. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  944. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  945. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  946. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  947. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  948. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  949. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  950. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  951. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  952. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  953. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  954. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  955. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  956. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  957. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  958. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  959. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  960. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  961. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  962. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  963. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  964. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  965. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  966. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  967. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  968. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  969. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  970. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  971. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  972. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  973. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  974. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  975. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  976. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  977. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  978. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  979. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  980. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  981. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  982. mindspore/ops/_op_impl/cpu/div.py +32 -0
  983. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  984. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  985. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  986. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  987. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  988. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  989. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  990. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  991. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  992. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  993. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  994. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  995. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  996. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  997. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  998. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  999. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  1000. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  1001. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  1002. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  1003. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  1004. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  1005. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  1006. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  1007. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  1009. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  1010. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  1011. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  1012. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  1013. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  1014. mindspore/ops/_op_impl/cpu/range.py +34 -0
  1015. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  1016. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  1017. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  1018. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  1019. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  1020. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  1021. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  1022. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1023. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1024. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1025. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1026. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1027. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1028. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1029. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1030. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1031. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1032. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1033. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1034. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1035. mindspore/ops/_primitive_cache.py +90 -0
  1036. mindspore/ops/_register_for_op.py +73 -0
  1037. mindspore/ops/_utils/__init__.py +20 -0
  1038. mindspore/ops/_utils/utils.py +147 -0
  1039. mindspore/ops/_vmap/__init__.py +25 -0
  1040. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1041. mindspore/ops/_vmap/vmap_base.py +533 -0
  1042. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1043. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1044. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1045. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1046. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1047. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1048. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1049. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1050. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1051. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1052. mindspore/ops/auto_generate/__init__.py +31 -0
  1053. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1054. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1055. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1056. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1057. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1058. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1059. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1060. mindspore/ops/composite/__init__.py +71 -0
  1061. mindspore/ops/composite/base.py +1318 -0
  1062. mindspore/ops/composite/env_ops.py +41 -0
  1063. mindspore/ops/composite/math_ops.py +125 -0
  1064. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1065. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1066. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1067. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1068. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1069. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1070. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1071. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1072. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1073. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1074. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1075. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1076. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1077. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1078. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1079. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1080. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1081. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1082. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1083. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1084. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1085. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1086. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1087. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1088. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1089. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1090. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1091. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1092. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1093. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1094. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1095. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1096. mindspore/ops/deprecated.py +315 -0
  1097. mindspore/ops/function/__init__.py +782 -0
  1098. mindspore/ops/function/array_func.py +7226 -0
  1099. mindspore/ops/function/clip_func.py +384 -0
  1100. mindspore/ops/function/debug_func.py +181 -0
  1101. mindspore/ops/function/fft_func.py +44 -0
  1102. mindspore/ops/function/grad/__init__.py +34 -0
  1103. mindspore/ops/function/grad/grad_func.py +1425 -0
  1104. mindspore/ops/function/image_func.py +292 -0
  1105. mindspore/ops/function/linalg_func.py +416 -0
  1106. mindspore/ops/function/math_func.py +12228 -0
  1107. mindspore/ops/function/nn_func.py +8609 -0
  1108. mindspore/ops/function/other_func.py +115 -0
  1109. mindspore/ops/function/parameter_func.py +134 -0
  1110. mindspore/ops/function/random_func.py +1715 -0
  1111. mindspore/ops/function/reshard_func.py +104 -0
  1112. mindspore/ops/function/sparse_func.py +884 -0
  1113. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1114. mindspore/ops/function/spectral_func.py +150 -0
  1115. mindspore/ops/function/vmap_func.py +117 -0
  1116. mindspore/ops/functional.py +464 -0
  1117. mindspore/ops/op_info_register.py +1572 -0
  1118. mindspore/ops/operations/__init__.py +722 -0
  1119. mindspore/ops/operations/_csr_ops.py +403 -0
  1120. mindspore/ops/operations/_custom_grad.py +181 -0
  1121. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1122. mindspore/ops/operations/_grad_ops.py +2978 -0
  1123. mindspore/ops/operations/_infer_ops.py +19 -0
  1124. mindspore/ops/operations/_inner_ops.py +2544 -0
  1125. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1126. mindspore/ops/operations/_ms_kernel.py +601 -0
  1127. mindspore/ops/operations/_ocr_ops.py +379 -0
  1128. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1129. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1130. mindspore/ops/operations/_quant_ops.py +1844 -0
  1131. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1132. mindspore/ops/operations/_scalar_ops.py +106 -0
  1133. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1134. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1135. mindspore/ops/operations/_tensor_array.py +359 -0
  1136. mindspore/ops/operations/_thor_ops.py +807 -0
  1137. mindspore/ops/operations/array_ops.py +6124 -0
  1138. mindspore/ops/operations/comm_ops.py +1985 -0
  1139. mindspore/ops/operations/control_ops.py +127 -0
  1140. mindspore/ops/operations/custom_ops.py +1129 -0
  1141. mindspore/ops/operations/debug_ops.py +678 -0
  1142. mindspore/ops/operations/image_ops.py +1041 -0
  1143. mindspore/ops/operations/inner_ops.py +697 -0
  1144. mindspore/ops/operations/linalg_ops.py +95 -0
  1145. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1146. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1147. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1148. mindspore/ops/operations/math_ops.py +5095 -0
  1149. mindspore/ops/operations/nn_ops.py +9575 -0
  1150. mindspore/ops/operations/other_ops.py +874 -0
  1151. mindspore/ops/operations/random_ops.py +1288 -0
  1152. mindspore/ops/operations/reshard_ops.py +53 -0
  1153. mindspore/ops/operations/rl_ops.py +288 -0
  1154. mindspore/ops/operations/sparse_ops.py +2753 -0
  1155. mindspore/ops/operations/spectral_ops.py +111 -0
  1156. mindspore/ops/primitive.py +1046 -0
  1157. mindspore/ops/signature.py +54 -0
  1158. mindspore/ops/vm_impl_registry.py +91 -0
  1159. mindspore/ops_generate/__init__.py +27 -0
  1160. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1161. mindspore/ops_generate/arg_handler.py +197 -0
  1162. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1163. mindspore/ops_generate/gen_constants.py +36 -0
  1164. mindspore/ops_generate/gen_ops.py +1099 -0
  1165. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1166. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1167. mindspore/ops_generate/gen_utils.py +209 -0
  1168. mindspore/ops_generate/op_proto.py +145 -0
  1169. mindspore/ops_generate/pyboost_utils.py +367 -0
  1170. mindspore/ops_generate/template.py +261 -0
  1171. mindspore/parallel/__init__.py +30 -0
  1172. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1173. mindspore/parallel/_cell_wrapper.py +174 -0
  1174. mindspore/parallel/_cost_model_context.py +700 -0
  1175. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1176. mindspore/parallel/_offload_context.py +275 -0
  1177. mindspore/parallel/_parallel_serialization.py +561 -0
  1178. mindspore/parallel/_ps_context.py +242 -0
  1179. mindspore/parallel/_recovery_context.py +110 -0
  1180. mindspore/parallel/_tensor.py +730 -0
  1181. mindspore/parallel/_transformer/__init__.py +35 -0
  1182. mindspore/parallel/_transformer/layers.py +765 -0
  1183. mindspore/parallel/_transformer/loss.py +251 -0
  1184. mindspore/parallel/_transformer/moe.py +693 -0
  1185. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1186. mindspore/parallel/_transformer/transformer.py +3119 -0
  1187. mindspore/parallel/_utils.py +612 -0
  1188. mindspore/parallel/algo_parameter_config.py +400 -0
  1189. mindspore/parallel/checkpoint_transform.py +650 -0
  1190. mindspore/parallel/cluster/__init__.py +15 -0
  1191. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1192. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1193. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1194. mindspore/parallel/cluster/run.py +136 -0
  1195. mindspore/parallel/mpi/__init__.py +14 -0
  1196. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1197. mindspore/parallel/parameter_broadcast.py +151 -0
  1198. mindspore/parallel/shard.py +481 -0
  1199. mindspore/parallel/transform_safetensors.py +993 -0
  1200. mindspore/perf_msvcbuildinsights.dll +0 -0
  1201. mindspore/pgodb140.dll +0 -0
  1202. mindspore/pgort140.dll +0 -0
  1203. mindspore/profiler/__init__.py +28 -0
  1204. mindspore/profiler/common/__init__.py +14 -0
  1205. mindspore/profiler/common/constant.py +29 -0
  1206. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1207. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1208. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1209. mindspore/profiler/common/process_pool.py +41 -0
  1210. mindspore/profiler/common/registry.py +47 -0
  1211. mindspore/profiler/common/singleton.py +28 -0
  1212. mindspore/profiler/common/struct_type.py +118 -0
  1213. mindspore/profiler/common/util.py +472 -0
  1214. mindspore/profiler/common/validator/__init__.py +14 -0
  1215. mindspore/profiler/common/validator/validate_path.py +84 -0
  1216. mindspore/profiler/dynamic_profiler.py +694 -0
  1217. mindspore/profiler/envprofiling.py +254 -0
  1218. mindspore/profiler/parser/__init__.py +14 -0
  1219. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1220. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1221. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1222. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1223. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1227. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1228. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1229. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1230. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1231. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1232. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1233. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1234. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1235. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1236. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1237. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1238. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1239. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1240. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1241. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1242. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1243. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1244. mindspore/profiler/parser/container.py +229 -0
  1245. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1246. mindspore/profiler/parser/flops_parser.py +531 -0
  1247. mindspore/profiler/parser/framework_enum.py +111 -0
  1248. mindspore/profiler/parser/framework_parser.py +464 -0
  1249. mindspore/profiler/parser/framework_struct.py +61 -0
  1250. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1251. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1252. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1253. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1254. mindspore/profiler/parser/hccl_parser.py +573 -0
  1255. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1256. mindspore/profiler/parser/integrator.py +526 -0
  1257. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1258. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1259. mindspore/profiler/parser/minddata_parser.py +186 -0
  1260. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1261. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1262. mindspore/profiler/parser/optime_parser.py +250 -0
  1263. mindspore/profiler/parser/profiler_info.py +213 -0
  1264. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1265. mindspore/profiler/profiler.py +153 -0
  1266. mindspore/profiler/profiling.py +1922 -0
  1267. mindspore/rewrite/__init__.py +28 -0
  1268. mindspore/rewrite/api/__init__.py +17 -0
  1269. mindspore/rewrite/api/node.py +519 -0
  1270. mindspore/rewrite/api/node_type.py +53 -0
  1271. mindspore/rewrite/api/pattern_engine.py +490 -0
  1272. mindspore/rewrite/api/scoped_value.py +181 -0
  1273. mindspore/rewrite/api/symbol_tree.py +497 -0
  1274. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1275. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1276. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1277. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1278. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1279. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1280. mindspore/rewrite/common/__init__.py +19 -0
  1281. mindspore/rewrite/common/config.py +24 -0
  1282. mindspore/rewrite/common/error_log.py +39 -0
  1283. mindspore/rewrite/common/event.py +28 -0
  1284. mindspore/rewrite/common/namer.py +271 -0
  1285. mindspore/rewrite/common/namespace.py +118 -0
  1286. mindspore/rewrite/common/observable.py +44 -0
  1287. mindspore/rewrite/common/observer.py +54 -0
  1288. mindspore/rewrite/node/__init__.py +22 -0
  1289. mindspore/rewrite/node/call_function.py +95 -0
  1290. mindspore/rewrite/node/cell_container.py +139 -0
  1291. mindspore/rewrite/node/control_flow.py +113 -0
  1292. mindspore/rewrite/node/node.py +1428 -0
  1293. mindspore/rewrite/node/node_manager.py +283 -0
  1294. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1295. mindspore/rewrite/parsers/__init__.py +29 -0
  1296. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1297. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1298. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1299. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1300. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1301. mindspore/rewrite/parsers/container_parser.py +88 -0
  1302. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1303. mindspore/rewrite/parsers/for_parser.py +61 -0
  1304. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1305. mindspore/rewrite/parsers/if_parser.py +85 -0
  1306. mindspore/rewrite/parsers/module_parser.py +117 -0
  1307. mindspore/rewrite/parsers/parser.py +43 -0
  1308. mindspore/rewrite/parsers/parser_register.py +86 -0
  1309. mindspore/rewrite/parsers/return_parser.py +37 -0
  1310. mindspore/rewrite/parsers/while_parser.py +59 -0
  1311. mindspore/rewrite/sparsify/__init__.py +0 -0
  1312. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1313. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1314. mindspore/rewrite/sparsify/utils.py +179 -0
  1315. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1316. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1317. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1318. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1319. mindspore/run_check/__init__.py +20 -0
  1320. mindspore/run_check/_check_version.py +507 -0
  1321. mindspore/run_check/run_check.py +66 -0
  1322. mindspore/safeguard/__init__.py +18 -0
  1323. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1324. mindspore/swresample-4.dll +0 -0
  1325. mindspore/swscale-6.dll +0 -0
  1326. mindspore/tbbmalloc.dll +0 -0
  1327. mindspore/tinyxml2.dll +0 -0
  1328. mindspore/train/__init__.py +48 -0
  1329. mindspore/train/_utils.py +465 -0
  1330. mindspore/train/amp.py +935 -0
  1331. mindspore/train/anf_ir_pb2.py +1517 -0
  1332. mindspore/train/callback/__init__.py +44 -0
  1333. mindspore/train/callback/_backup_and_restore.py +117 -0
  1334. mindspore/train/callback/_callback.py +613 -0
  1335. mindspore/train/callback/_checkpoint.py +814 -0
  1336. mindspore/train/callback/_cluster_monitor.py +201 -0
  1337. mindspore/train/callback/_dataset_graph.py +150 -0
  1338. mindspore/train/callback/_early_stop.py +239 -0
  1339. mindspore/train/callback/_flops_collector.py +239 -0
  1340. mindspore/train/callback/_history.py +92 -0
  1341. mindspore/train/callback/_lambda_callback.py +80 -0
  1342. mindspore/train/callback/_landscape.py +1049 -0
  1343. mindspore/train/callback/_loss_monitor.py +107 -0
  1344. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1345. mindspore/train/callback/_on_request_exit.py +298 -0
  1346. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1347. mindspore/train/callback/_summary_collector.py +1184 -0
  1348. mindspore/train/callback/_tft_register.py +352 -0
  1349. mindspore/train/callback/_time_monitor.py +141 -0
  1350. mindspore/train/checkpoint_pb2.py +233 -0
  1351. mindspore/train/data_sink.py +219 -0
  1352. mindspore/train/dataset_helper.py +692 -0
  1353. mindspore/train/lineage_pb2.py +1260 -0
  1354. mindspore/train/loss_scale_manager.py +213 -0
  1355. mindspore/train/memory_profiling_pb2.py +298 -0
  1356. mindspore/train/metrics/__init__.py +175 -0
  1357. mindspore/train/metrics/accuracy.py +133 -0
  1358. mindspore/train/metrics/auc.py +129 -0
  1359. mindspore/train/metrics/bleu_score.py +170 -0
  1360. mindspore/train/metrics/confusion_matrix.py +700 -0
  1361. mindspore/train/metrics/cosine_similarity.py +109 -0
  1362. mindspore/train/metrics/dice.py +116 -0
  1363. mindspore/train/metrics/error.py +175 -0
  1364. mindspore/train/metrics/fbeta.py +167 -0
  1365. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1366. mindspore/train/metrics/loss.py +97 -0
  1367. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1368. mindspore/train/metrics/metric.py +373 -0
  1369. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1370. mindspore/train/metrics/perplexity.py +133 -0
  1371. mindspore/train/metrics/precision.py +160 -0
  1372. mindspore/train/metrics/recall.py +159 -0
  1373. mindspore/train/metrics/roc.py +223 -0
  1374. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1375. mindspore/train/metrics/topk.py +167 -0
  1376. mindspore/train/mind_ir_pb2.py +1908 -0
  1377. mindspore/train/model.py +2252 -0
  1378. mindspore/train/node_strategy_pb2.py +653 -0
  1379. mindspore/train/print_pb2.py +184 -0
  1380. mindspore/train/profiling_parallel_pb2.py +151 -0
  1381. mindspore/train/serialization.py +3325 -0
  1382. mindspore/train/summary/__init__.py +23 -0
  1383. mindspore/train/summary/_lineage_adapter.py +41 -0
  1384. mindspore/train/summary/_summary_adapter.py +496 -0
  1385. mindspore/train/summary/_writer_pool.py +207 -0
  1386. mindspore/train/summary/enums.py +56 -0
  1387. mindspore/train/summary/summary_record.py +581 -0
  1388. mindspore/train/summary/writer.py +167 -0
  1389. mindspore/train/summary_pb2.py +1165 -0
  1390. mindspore/train/train_thor/__init__.py +20 -0
  1391. mindspore/train/train_thor/convert_utils.py +268 -0
  1392. mindspore/train/train_thor/dataset_helper.py +192 -0
  1393. mindspore/train/train_thor/model_thor.py +257 -0
  1394. mindspore/turbojpeg.dll +0 -0
  1395. mindspore/utils/__init__.py +21 -0
  1396. mindspore/utils/utils.py +60 -0
  1397. mindspore/vcmeta.dll +0 -0
  1398. mindspore/vcomp140.dll +0 -0
  1399. mindspore/vcruntime140.dll +0 -0
  1400. mindspore/vcruntime140_1.dll +0 -0
  1401. mindspore/version.py +1 -0
  1402. mindspore-2.4.0.dist-info/METADATA +352 -0
  1403. mindspore-2.4.0.dist-info/RECORD +1406 -0
  1404. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1405. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1406. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1428 @@
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """Node class define of Rewrite. See detail in Node class docstring."""
16
+ from typing import Optional, Union, List, Dict
17
+ import ast
18
+ import inspect
19
+ from types import FunctionType
20
+ import sys
21
+
22
+ from mindspore.nn import Cell
23
+ from mindspore.ops import Primitive
24
+ from mindspore import log as logger
25
+ from ..api.scoped_value import ScopedValue, ValueType
26
+ from ..api.node_type import NodeType
27
+ from ..common.namespace import is_subtree
28
+ from ..common.error_log import error_str
29
+ from ..ast_helpers import AstModifier, AstReplacer, AstConverter
30
+ from ... import _checkparam as Validator
31
+
32
+
33
+ if sys.version_info >= (3, 9):
34
+ import ast as astunparse # pylint: disable=reimported, ungrouped-imports
35
+ else:
36
+ import astunparse
37
+
38
+
39
class LocalPrim(Primitive):
    """Primitive subclass used by Rewrite to indicate a local primitive instance."""

    def __init__(self, prim_obj: type):
        """
        Register the primitive under the fixed name "rewrite_local_prim" and keep the wrapped object.

        Args:
            prim_obj (type): The object this local primitive stands for; stored as `self.prim_obj`.
        """
        super(LocalPrim, self).__init__("rewrite_local_prim")
        self.prim_obj = prim_obj
44
+
45
+
46
+ class Node:
47
+ """
48
+ Node is a data structure represents a source code line in network. For the most part, Node represents an operator
49
+ invoking in forward which could be an instance of Cell, an instance of Primitive or a callable method. Fields of
50
+ Node has different meaning in different type of node:
51
+
52
+ - CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore.
53
+ `targets` is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and
54
+ `kwargs` are corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward
55
+ method. `func` is corresponding to func of call expression which means symbol of the cell-op.
56
+ - CallPrimitive: a call-primitive node represents an ast.Assign whose value is a calling to operator in mindspore.
57
+ `targets`, `args`, `kwargs` and `func_name` are as previous.
58
+ - CallMethod: a call-method node represents an ast.Assign whose value is a calling to python-method such as `len`.
59
+ `targets` is corresponding to targets of ast.Assign which means return values of this method. `func_name`
60
+ represents the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the
61
+ method. When value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be
62
+ mapped to CallMethod node whose `func_name` is "PassThrough".
63
+ - Python: a python node holds an ast-node which is not parsed. a python node means some python statement is not
64
+ supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `func_name` are don't-care.
65
+ - Input: an input node represents an input of current network which also a parameter of forward method of Cell.
66
+ `targets` is corresponding to arg-name of parameter of forward function. `args` means default-value of parameter
67
+ of forward function. `kwargs` and `func_name` are don't-care.
68
+ - Output: an output node represents the output of current network which is corresponding to return statement of
69
+ forward method of Cell. `args` represents return values. `func_name` are always be "return". `targets` and
70
+ `kwargs` are don't-care.
71
+ - Tree: a tree node represents a sub-network call in current network. A sub-network is also a Cell in mindspore, so
72
+ `targets`, `args`, `kwargs` and `func_name` are same as a call-cell node. `symbol_tree` is a handler of a
73
+ SymbolTree instance.
74
+ """
75
+
76
def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], targets: [ScopedValue],
             func_name: Optional[ScopedValue], args: List[ScopedValue], kwargs: Dict[str, ScopedValue], name: str,
             instance):
    """
    Constructor of Node. Rewrite recommends instantiating a Node through the factory methods of Node, such as
    `create_call_op`, `create_call_method`, `create_python_node`, `create_input_node` and
    `create_output_node`, rather than invoking this constructor directly.

    Args:
        node_type (NodeType): A NodeType as type of Node.
        ast_node (ast.AST, optional): An instance of ast.AST representing the corresponding node in ast.
            `ast_node` should not be None except when node type is Unknown.
        targets (list[ScopedValue]): A list of ScopedValue instances. ``None`` is stored as an empty list.
            See detail in docstring of Node class.
        func_name (ScopedValue, optional): An instance of ScopedValue. See detail in docstring of Node class.
        args (list[ScopedValue]): A list of ScopedValue instances. ``None`` is counted as zero positional
            arguments. See detail in docstring of Node class.
        kwargs (Dict[str, ScopedValue]): A dict whose keys are keyword names and whose values are ScopedValue
            instances. ``None`` is counted as zero keyword arguments. See detail in docstring of Node class.
        name (str): A string representing the name of node. Name of node will be unique when inserted into
            SymbolTree. Name of node is also used as field name in network class.
        instance: Object in network corresponding to this node.
    """
    self._node_type: NodeType = node_type
    self._ast_node: Optional[ast.AST] = ast_node
    # Attributes are only collected for CallModule / CallCell / CallPrimitive nodes; other types keep an empty dict.
    self._attribute: {str, object} = {}
    if node_type in (NodeType.CallModule, NodeType.CallCell, NodeType.CallPrimitive):
        self._attribute = Node._get_cell_or_prim_op_attribute(instance)
    self._instance = instance
    self._name = name
    self._func_name: Optional[ScopedValue] = func_name
    self._targets: [ScopedValue] = targets if targets is not None else []
    # Counts of positional and keyword arguments as passed in (None counts as 0).
    self._args_num = len(args) if args is not None else 0
    self._kwargs_num = len(kwargs) if kwargs is not None else 0
    # Must be created before `_get_normalized_args` is called on the next line.
    # NOTE(review): presumably `_get_normalized_args` fills this list with the argument keys - confirm.
    self._normalized_args_keys = []  # for saving args' order
    self._normalized_args = self._get_normalized_args(args, kwargs)
    # position in graph nodes list
    # it will affect code-order of python code
    self._prev: Optional[Node] = None
    self._next: Optional[Node] = None
    # A handler of SymbolTree current node belonging to
    self._belong_tree = None
    # A handler of NodeManager current node belonging to
    self._node_manager = None
    # A dict that records which target of which Node current Node's argument come from
    self._arg_providers: {int: (Node, int)} = {}
    # A dict that records which argument of which Node uses current Node's target
    self._target_users: {int: [(Node, int)]} = {}
    # Indicate this node represent a class type object, e.g. abs_ops = _get_cache_prim(P.Abs)
    self._type_cls = None
    # Indicate this node represent the initialize of a class type, e.g. abs_inst = P.Abs()
    self._init_cls = None
125
+
126
@classmethod
def create_call_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
                       func_name: Union[ScopedValue, str], args: [ScopedValue] = None,
                       kwargs: {str: ScopedValue}=None, name: str = ""):
    """
    Class method of Node. Instantiate a node whose type is CallMethod. A CallMethod node represents an
    invocation of a python method.

    Args:
        ast_node ([ast.AST, optional]): An instance of ast.AST representing the corresponding node in ast.
            `ast_node` must not be None currently.
        targets (list[Union[ScopedValue, str]]): Targets of the node; they are passed through
            `Node._handle_targets` before being stored. See detail in docstring of Node class.
        func_name (Union[ScopedValue, str]): Name of the method. A `str` is converted to a naming
            ScopedValue. See detail in docstring of Node class.
        args (list[ScopedValue]): Positional arguments. Default: ``None`` , treated as an empty list.
        kwargs (dict{str: ScopedValue}): Keyword arguments. Default: ``None`` , treated as an empty dict.
        name (str): A string representing the name of node. Name of node will be unique when inserted into
            SymbolTree. Name of node is also used as field name in network class.

    Returns:
        An instance of `cls` with node type `NodeType.CallMethod`.

    Raises:
        RuntimeError: If `ast_node` is None.
    """
    call_args = [] if args is None else args
    call_kwargs = {} if kwargs is None else kwargs
    scoped_func = ScopedValue.create_naming_value(func_name) if isinstance(func_name, str) else func_name
    handled_targets = Node._handle_targets(targets)
    # The None check deliberately stays after target handling, matching the established evaluation order.
    if ast_node is None:
        raise RuntimeError("Input ast_node is None")
    return cls(NodeType.CallMethod, ast_node, handled_targets, scoped_func, call_args, call_kwargs, name, None)
154
+
155
@classmethod
def create_python_node(cls, ast_node: ast.AST, name: str = "", instance=None):
    """
    Class method of Node. Instantiate a node whose type is Python. A Python node holds a python statement
    that is not supported by Rewrite or is ignored by Rewrite.

    Args:
        ast_node (ast.AST): An instance of ast.AST representing the corresponding node in ast.
        name (str): A string representing the name of node. Name of node will be unique when inserted into
            SymbolTree. Name of node is also used as field name in network class.
        instance: An object corresponding to this node in network.

    Returns:
        An instance of `cls` with node type `NodeType.Python`, no targets, no function name and no arguments.
    """
    empty_args = []
    empty_kwargs = {}
    return cls(NodeType.Python, ast_node, None, None, empty_args, empty_kwargs, name, instance)
168
+
169
@classmethod
def create_input_node(cls, ast_node: Optional[ast.AST], arg_name: str, default: Optional[ScopedValue] = None,
                      name: str = ""):
    """
    Class method of Node. Instantiate a node whose type is Input. An Input node represents an input of
    SymbolTree, which corresponds to a parameter of the forward function.

    Args:
        ast_node (ast.AST, optional): An instance of ast.AST representing the corresponding node in ast.
            When it is None, an `ast.arg` named `arg_name` is created and used instead.
        arg_name (str): A string representing the name of the parameter.
        default ([ScopedValue, optional]): An instance of ScopedValue representing the default value of the
            parameter. When given, it becomes the only element of the node's args.
        name (str): A string representing the name of node. Name of node will be unique when inserted into
            SymbolTree. Name of node is also used as field name in network class.

    Returns:
        An instance of `cls` with node type `NodeType.Input`.
    """
    param_target = ScopedValue.create_naming_value(arg_name)
    default_args = [] if default is None else [default]
    node_ast = ast_node if ast_node is not None else ast.arg(arg_name, annotation="")
    return cls(NodeType.Input, node_ast, [param_target], None, default_args, {}, name, None)
191
+
192
@classmethod
def create_output_node(cls, ast_node: ast.AST, return_value: [ScopedValue], name: str = "return"):
    """
    Class method of Node. Instantiate a node whose type is Output. An Output node represents the output of
    SymbolTree, which corresponds to the return statement of the forward function.

    Args:
        ast_node (ast.AST): An instance of ast.AST representing the corresponding node in ast.
        return_value (list[ScopedValue]): Return values; they are stored as the args of the node.
        name (str): A string representing the name of node. Default: ``"return"`` .

    Returns:
        An instance of `cls` with node type `NodeType.Output` whose function name is the naming value
        "return".
    """
    return_func = ScopedValue.create_naming_value("return")
    no_kwargs = {}
    return cls(NodeType.Output, ast_node, None, return_func, return_value, no_kwargs, name, None)
205
+
206
@classmethod
def create_mathops_node(cls, ast_node: ast.AST, targets: [ScopedValue],
                        op_type: ScopedValue, args: [ScopedValue], name: str = ""):
    """
    Class method of Node. Instantiate a node whose type is `MathOps` .
    A mathops node represents a statement made of mathematical operations, such as
    `y = a + b` , `y = not a` , `y = 0 < a < 1`, `y = a or b` , etc.

    Args:
        ast_node ([ast.AST, optional]): An instance of ast.AST representing the corresponding node in ast.
            The type of node is ast.Assign, and the type of ast_node.value is one of ast.BinOp, ast.UnaryOp,
            ast.BoolOp and ast.Compare.
        targets (list[ScopedValue]): Targets of the mathematical operations. A list of `ScopedValue`
            instances. See detail in docstring of Node class.
        op_type (ScopedValue): The type of ast_node.value saved as a string. A ScopedValue with NamingValue
            type. It is stored as the function name of the node.
        args (list[ScopedValue]): Values participating in the mathematical operations, saved sequentially in
            the list.
        name (str): A string representing the name of node. Name of node will be unique when inserted into
            `SymbolTree`. Name of node is also used as field name in network class. The format of a mathops
            node name is 'AstNodeName_AstOpName_n'.

    Returns:
        An instance of `cls` with node type `NodeType.MathOps`; kwargs are passed as None.
    """
    # A mathops node carries no keyword arguments, so None is forwarded for kwargs.
    no_kwargs = None
    return cls(NodeType.MathOps, ast_node, targets, op_type, args, no_kwargs, name, None)
228
+
229
@staticmethod
def _create_call_function(function: FunctionType, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
                          kwargs: {str: ScopedValue}=None):
    """
    Create a node that corresponds to a function call.

    The node name and the function's naming value are both taken from `function.__name__`, and no ast node
    is supplied to `Node.inner_create_call_function`.

    Args:
        function (FunctionType): The function to be called.
        targets (list[Union[ScopedValue, str]]): Output names, used as targets of an assign statement in
            source code. They are passed through `Node._handle_targets` first.
        args (list[ScopedValue]): Input names, used as args of the call expression of an assign statement in
            source code. Default: ``None`` , which indicates the `function` has no args inputs.
        kwargs (dict): Type of key must be `str` and type of value must be `ScopedValue`.
            Keyword input names, used as kwargs of the call expression of an assign statement in source
            code. Default: ``None`` , which indicates the `function` has no kwargs inputs.

    Returns:
        An instance of `Node`.
    """
    call_args = [] if args is None else args
    call_kwargs = {} if kwargs is None else kwargs
    handled_targets = Node._handle_targets(targets)
    plain_name = function.__name__
    scoped_name = ScopedValue.create_naming_value(plain_name)
    return Node.inner_create_call_function(plain_name, None, scoped_name, function, handled_targets,
                                           call_args, call_kwargs)
256
+
257
@classmethod
def inner_create_call_function(cls, node_name: str, ast_node: ast.Assign, func_name: ScopedValue, func_obj: object,
                               targets: List[ScopedValue], args: List[ScopedValue], kwargs: Dict[str, ScopedValue]):
    """
    Instantiate a node whose type is `CallFunction`.

    Args:
        node_name (str): Name of node.
        ast_node ([ast.AST, optional]): An instance of ast.AST representing the corresponding node in ast.
        func_name (ScopedValue): Name of function.
        func_obj (Object): An object of function. See detail in docstring of Node class.
        targets (List[ScopedValue]): A list of `ScopedValue` instances. See detail in docstring of Node class.
        args (List[ScopedValue]): A list of `ScopedValue` instances. See detail in docstring of Node class.
        kwargs (Dict[str, ScopedValue]): A dict of keyword name to `ScopedValue`. See detail in docstring of
            `Node` class.

    Returns:
        A `CallFunction` instance.
    """
    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid a circular import with the package - confirm.
    from . import CallFunction
    # NOTE(review): the two None values and the trailing False are fixed positional arguments of
    # CallFunction; their meaning is defined by CallFunction's constructor, not visible here.
    call_function_node = CallFunction(targets, func_name, args, kwargs, node_name, ast_node, None, None,
                                      func_obj, False)
    return call_function_node
276
+
277
@staticmethod
def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
                   args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, node_name: str = "",
                   is_sub_net: bool = False):
    """
    Validate the inputs and create a node for invoking `op`.

    When `is_sub_net` is true and `is_subtree(op)` holds, a symbol tree is built from `op` and a tree node is
    returned. Otherwise the work is delegated to `Node.create_call_buildin_op`, which yields a `CallCell` or
    `CallPrimitive` node.

    Args:
        op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
        ast_node ([ast.AST, optional]): Ast node corresponding to this node; only type-checked when not None.
        targets (list[Union[ScopedValue, str]]): Targets of the call; str elements are normalized by
            `Node._handle_targets`.
        args (list[ScopedValue]): Positional arguments. Default: ``None`` , treated as an empty list.
        kwargs (dict{str: ScopedValue}): Keyword arguments. Default: ``None`` , treated as an empty dict.
        node_name (str): Name of the node; also used to create the function-name `ScopedValue`.
        is_sub_net (bool): Whether `op` should be tried as a sub-network (tree node). Default: ``False`` .

    Returns:
        The node created by `TreeNode.create_tree_node` or by `Node.create_call_buildin_op`.
    """
    owner = "Node"
    Validator.check_value_type("op", op, [Cell, Primitive], owner)
    if ast_node is not None:
        Validator.check_value_type("ast_node", ast_node, [ast.AST], owner)
    Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], owner)
    if args is None:
        args = []
    else:
        Validator.check_element_type_of_iterable("args", args, [ScopedValue], owner)
    if kwargs is None:
        kwargs = {}
    else:
        Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], owner)
    Validator.check_value_type("node_name", node_name, [str], owner)

    normalized_targets = Node._handle_targets(targets)
    func_name = ScopedValue.create_naming_value(node_name) if isinstance(node_name, str) else node_name

    if not (is_sub_net and is_subtree(op)):
        return Node.create_call_buildin_op(op, ast_node, normalized_targets, func_name, args, kwargs, node_name)

    # Sub-network path: build a symbol tree for `op` and rename the original class name to the optimized one
    # inside the class ast before wrapping it into a tree node.
    from ..symbol_tree import SymbolTreeBuilder
    symbol_tree = SymbolTreeBuilder(op).build()
    class_replacer = AstReplacer(symbol_tree.get_class_ast())
    class_replacer.replace_all(symbol_tree.get_ori_cls_name(), symbol_tree.get_opt_cls_name())
    return TreeNode.create_tree_node(symbol_tree, ast_node, normalized_targets, func_name, args, kwargs, node_name,
                                     op)
326
+
327
@classmethod
def create_call_buildin_op(cls, op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [ScopedValue],
                           func_name: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
                           node_name: str = ""):
    """
    Create a `CallCell` node (when `op` is a `Cell`) or a `CallPrimitive` node (any other accepted `op`).

    Args:
        op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
        ast_node ([ast.AST, optional]): Ast node corresponding to this node.
        targets (list[ScopedValue]): Targets of the call.
        func_name (ScopedValue): Name of the called op.
        args (list[ScopedValue]): Positional arguments, forwarded as given. Default: ``None`` .
        kwargs (dict{str: ScopedValue}): Keyword arguments, forwarded as given. Default: ``None`` .
        node_name (str): Name of the node. Default: ``""`` .

    Returns:
        A new instance of `cls`.

    Raises:
        ValueError: If `op` is neither a `Cell` nor a `Primitive`.
    """
    is_buildin_op = isinstance(op, (Cell, Primitive))
    if not is_buildin_op:
        raise ValueError("Input op is not a buildin op(Cell or Primitive): ", type(op))
    node_type = NodeType.CallCell if isinstance(op, Cell) else NodeType.CallPrimitive
    return cls(node_type, ast_node, targets, func_name, args, kwargs, node_name, op)
355
+
356
+ @staticmethod
357
+ def _get_construct_arg_names(parameters):
358
+ """
359
+ Static method of `Node`. Get parameters' names of the construct function.
360
+
361
+ Args:
362
+ parameters (MappingProxyType): An ordered mapping of parameters' names to the corresponding Parameter
363
+ objects.
364
+
365
+ Raises:
366
+ RuntimeError: Invalid parameter kind.
367
+
368
+ Returns:
369
+ - arg_names, Parameters' names, contain parameters of types in [POSITIONAL_ONLY, POSITIONAL_OR_KEYWORD].
370
+ - var_positional_name, Name of VAR_POSITIONAL parameters.
371
+ - var_keyword_name, Name of VAR_KEYWORD parameters.
372
+ """
373
+ position_only_names: [str] = []
374
+ positional_or_keyword_names: [str] = []
375
+ var_positional_name = None
376
+ keyword_only_names: [str] = []
377
+ var_keyword_name = None
378
+ for name, para in parameters.items():
379
+ if para.kind == inspect.Parameter.POSITIONAL_ONLY: # parameters which appear before a '/'
380
+ position_only_names.append(name)
381
+ elif para.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: # parameters which appear before '*' or '*args'
382
+ positional_or_keyword_names.append(name)
383
+ elif para.kind == inspect.Parameter.VAR_POSITIONAL: # corresponds to a '*args'
384
+ var_positional_name = name
385
+ elif para.kind == inspect.Parameter.KEYWORD_ONLY: # parameters which appear after '*' and before '**'
386
+ keyword_only_names.append(name)
387
+ elif para.kind == inspect.Parameter.VAR_KEYWORD: # corresponds to a '**kwargs'
388
+ var_keyword_name = name
389
+ else:
390
+ raise RuntimeError("invalid parameter kind:", para.kind)
391
+ if "self" in position_only_names:
392
+ position_only_names.remove("self")
393
+ if "self" in positional_or_keyword_names:
394
+ positional_or_keyword_names.remove("self")
395
+ names = (position_only_names, positional_or_keyword_names, var_positional_name, keyword_only_names,
396
+ var_keyword_name)
397
+ return names
398
+
399
@staticmethod
def _map_args_names(names: tuple, args: [ScopedValue], kwargs: {str: ScopedValue},
                    normalized_args_keys: [str], normalized_args: {str: ScopedValue}):
    """
    Fill `normalized_args` and `normalized_args_keys` in place from `args` and `kwargs`.

    Positional arguments are keyed by the matching parameter name; surplus positional arguments are keyed
    "<var_positional_name>_<index>". Keyword arguments are appended afterwards, ordered by their position in
    the signature; keywords that are not declared parameters keep their incoming order and come last.

    Args:
        names (tuple): The 5-tuple (position_only_names, positional_or_keyword_names, var_positional_name,
            keyword_only_names, var_keyword_name).
        args (list[ScopedValue]): Positional arguments.
        kwargs (dict{str: ScopedValue}): Keyword arguments.
        normalized_args_keys (list[str]): Output list receiving the keys in normalized order.
        normalized_args (dict{str: ScopedValue}): Output dict receiving key -> argument.

    Raises:
        RuntimeError: More positional args than named parameters and there is no '*args' parameter.
        RuntimeError: A positional arg maps to a name that is also present in `kwargs`.
        RuntimeError: A keyword is not a declared parameter and there is no '**kwargs' parameter.
    """
    pos_only, pos_or_kw, var_pos, kw_only, var_kw = names

    named_positional = [*pos_only, *pos_or_kw]
    for position, value in enumerate(args):
        if position < len(named_positional):
            key = named_positional[position]
        elif var_pos:
            key = "{}_{}".format(var_pos, position)
        else:
            raise RuntimeError("Input args are invalid.")
        if key in kwargs:
            raise RuntimeError("Arg name already exist in kwargs.")
        normalized_args[key] = value
        normalized_args_keys.append(key)

    # Rank table for keyword arguments, following declaration order.
    # NOTE(review): `var_kw` appears twice (once between the positional and keyword-only names); the first
    # occurrence looks like it was meant to be the '*args' name. Kept as-is to preserve behavior - confirm.
    declared_order = [*pos_only, *pos_or_kw, var_kw, *kw_only, var_kw]

    next_extra_rank = len(declared_order)
    ranked_kwargs = []
    for key, value in kwargs.items():
        if key in declared_order:
            ranked_kwargs.append((declared_order.index(key), key, value))
        elif var_kw:
            # Undeclared keyword absorbed by '**kwargs': rank after every declared name, in incoming order.
            ranked_kwargs.append((next_extra_rank, key, value))
            next_extra_rank += 1
        else:
            raise RuntimeError("Input kwargs invalid.")

    ranked_kwargs.sort(key=lambda item: item[0])
    for _, key, value in ranked_kwargs:
        normalized_args[key] = value
        normalized_args_keys.append(key)
456
+
457
@staticmethod
def _handle_custom_obj_in_args(args: [ScopedValue]) -> [ScopedValue]:
    """
    Replace every `CustomObjValue` argument with the naming value `self.custom-object`.

    Args:
        args (list[ScopedValue]): Arguments to be converted.

    Returns:
        A new list in which arguments of type `ValueType.CustomObjValue` are replaced by
        `ScopedValue.create_naming_value("custom-object", "self")` and all others are kept unchanged.

    Raises:
        TypeError: If an element of `args` is not a `ScopedValue`.
    """
    converted = []
    for candidate in args:
        if not isinstance(candidate, ScopedValue):
            raise TypeError("arg should be ScopedValue, got: ", type(candidate))
        if candidate.type != ValueType.CustomObjValue:
            converted.append(candidate)
            continue
        logger.info("custom-object exist in args, should be replace before compile")
        converted.append(ScopedValue.create_naming_value("custom-object", "self"))
    return converted
478
+
479
@staticmethod
def _handle_custom_obj_in_kwargs(kwargs: {str: ScopedValue}) -> {str: ScopedValue}:
    """
    Replace every `CustomObjValue` keyword value with the naming value `self.custom-object`.

    Args:
        kwargs (dict{str: ScopedValue}): Keyword arguments whose values are to be converted.

    Returns:
        A new dict with the same keys; values of type `ValueType.CustomObjValue` are replaced by
        `ScopedValue.create_naming_value("custom-object", "self")` and all others are kept unchanged.

    Raises:
        TypeError: If a value of `kwargs` is not a `ScopedValue`.
    """
    converted = {}
    for keyword, scoped_value in kwargs.items():
        if not isinstance(scoped_value, ScopedValue):
            raise TypeError("value should be ScopedValue, got: ", type(scoped_value))
        is_custom_obj = scoped_value.type == ValueType.CustomObjValue
        converted[keyword] = (ScopedValue.create_naming_value("custom-object", "self") if is_custom_obj
                              else scoped_value)
    return converted
499
+
500
@staticmethod
def _handle_targets(targets: [Union[ScopedValue, str]]) -> [ScopedValue]:
    """
    Normalize `targets` into a list of `ScopedValue`.

    A str target is split on its last '.' into scope and name (scope is "" when there is no '.') and converted
    with `ScopedValue.create_naming_value`; `ScopedValue` targets are passed through.

    Args:
        targets (list[Union[ScopedValue, str]]): Targets to be normalized.

    Returns:
        A list of `ScopedValue`.

    Raises:
        TypeError: If `targets` is not a list.
        RuntimeError: If an element is neither a str nor a `ScopedValue`.
    """
    if not isinstance(targets, list):
        raise TypeError("targets should be list, got: ", type(targets))
    normalized = []
    for item in targets:
        if isinstance(item, ScopedValue):
            normalized.append(item)
        elif isinstance(item, str):
            # rpartition yields ("", "", item) when no '.' is present, i.e. empty scope and full name.
            scope, _, name = item.rpartition('.')
            normalized.append(ScopedValue.create_naming_value(name, scope))
        else:
            raise RuntimeError("Invalid symbol type: ", item)
    return normalized
527
+
528
+ @staticmethod
529
+ def _get_cell_or_prim_op_attribute(obj) -> dict:
530
+ """
531
+ Find attributes of cell-op or primitive-op.
532
+
533
+ Args:
534
+ obj: A cell-op or a primitive-op.
535
+
536
+ Returns:
537
+ A dict represents attributes of input 'obj'.
538
+ """
539
+ attributes = {}
540
+ if obj is None:
541
+ return attributes
542
+ for k, v in obj.__dict__.items():
543
+ if k.startswith("_"):
544
+ continue
545
+ attributes[k] = v
546
+ attributes["cls"] = obj.__class__
547
+ return attributes
548
+
549
def get_type_cls(self) -> object:
    """Return the class-type object recorded for this node (`_type_cls`), e.g. abs_ops = _get_cache_prim(P.Abs)."""
    type_cls = self._type_cls
    return type_cls
552
+
553
def set_type_cls(self, x):
    """Record `x` as the class-type object of this node (`_type_cls`), e.g. abs_ops = _get_cache_prim(P.Abs)."""
    setattr(self, "_type_cls", x)
556
+
557
def get_init_cls(self) -> object:
    """Return the class-type object initialized by this node (`_init_cls`), e.g. abs_inst = P.Abs()."""
    init_cls = self._init_cls
    return init_cls
560
+
561
def set_init_cls(self, x):
    """Record `x` as the class-type object initialized by this node (`_init_cls`)."""
    setattr(self, "_init_cls", x)
564
+
565
def get_prev(self) -> 'Node':
    """
    Return the node linked before this one in the source-code-order list.

    Returns:
        The value of `_prev`: a Node, or None when no previous node is linked.
    """
    previous_node = self._prev
    return previous_node
573
+
574
def get_next(self) -> 'Node':
    """
    Return the node linked after this one in the source-code-order list.

    Returns:
        The value of `_next`: a Node, or None when no next node is linked.
    """
    following_node = self._next
    return following_node
582
+
583
def set_prev(self, node: 'Node'):
    """
    Link `node` as the previous node of this one.

    Only this node's `_prev` is written; the other node's links are not touched.

    Args:
        node (Node): Node to become the previous node (None clears the link).
    """
    setattr(self, "_prev", node)
591
+
592
def set_next(self, node: 'Node'):
    """
    Link `node` as the next node of this one.

    Only this node's `_next` is written; the other node's links are not touched.

    Args:
        node (Node): Node to become the next node (None clears the link).
    """
    setattr(self, "_next", node)
600
+
601
def get_ast(self) -> Optional[ast.AST]:
    """
    Return the ast node attached to this node.

    Returns:
        The value of `_ast_node`: an instance of ast.AST, or None when no ast node is attached.
    """
    attached_ast = self._ast_node
    return attached_ast
609
+
610
def set_ast(self, ast_node: ast.AST):
    """
    Attach `ast_node` to this node.

    Args:
        ast_node (ast.AST): New value for `_ast_node`.

    Raises:
        TypeError: If `ast_node` is not an instance of ast.AST.
    """
    if isinstance(ast_node, ast.AST):
        self._ast_node = ast_node
        return
    raise TypeError("ast_node should be ast.AST, got: ", type(ast_node))
620
+
621
def get_belong_symbol_tree(self):
    """Return the symbol tree this node is recorded as belonging to (`_belong_tree`)."""
    owning_tree = self._belong_tree
    return owning_tree
624
+
625
def set_belong_symbol_tree(self, symbol_tree):
    """Record `symbol_tree` as the symbol tree this node belongs to (`_belong_tree`)."""
    setattr(self, "_belong_tree", symbol_tree)
628
+
629
def get_node_manager(self):
    """Return the NodeManager this node is recorded as belonging to (`_node_manager`)."""
    manager = self._node_manager
    return manager
632
+
633
def set_node_manager(self, node_manager):
    """Record `node_manager` as the NodeManager this node belongs to (`_node_manager`)."""
    setattr(self, "_node_manager", node_manager)
636
+
637
def isolate(self):
    """Unlink this node from the source-code-order list, joining its former neighbours to each other."""
    before = self.get_prev()
    after = self.get_next()
    # Bridge the neighbours over this node; either side may be absent at the ends of the list.
    if before is not None:
        before.set_next(after)
    if after is not None:
        after.set_prev(before)
    # Finally drop this node's own links.
    self.set_prev(None)
    self.set_next(None)
647
+
648
def insert_before(self, node: 'Node'):
    """
    Place `node` directly before this node in the source-code-order list.

    `node` is first detached from wherever it is currently linked. Topological order is not handled here.

    Args:
        node (Node): The node to be inserted.
    """
    # Detach first so that the predecessor read below is the one that will remain after `node` has moved.
    node.isolate()
    predecessor = self.get_prev()
    if predecessor is not None:
        predecessor.set_next(node)
    node.set_prev(predecessor)
    node.set_next(self)
    self.set_prev(node)
662
+
663
def insert_after(self, node: 'Node'):
    """
    Place `node` directly after this node in the source-code-order list.

    `node` is first detached from wherever it is currently linked. Topological order is not handled here.

    Args:
        node (Node): The node to be inserted.
    """
    # Detach first so that the successor read below is the one that will remain after `node` has moved.
    node.isolate()
    successor = self.get_next()
    self.set_next(node)
    node.set_prev(self)
    node.set_next(successor)
    if successor is None:
        return
    successor.set_prev(node)
677
+
678
def get_inputs(self) -> ['Node']:
    """
    Collect the nodes that provide this node's arguments.

    Returns:
        A list holding, for each non-empty entry of `get_arg_providers()` in iteration order, the first
        element of that entry (the providing node). Empty entries are skipped.
    """
    return [provider[0] for provider in self.get_arg_providers().values() if provider]
691
+
692
def get_users(self) -> ['Node']:
    """
    Collect the nodes that use any target of this node.

    Returns:
        A list of user nodes without duplicates, in first-seen order while iterating `get_target_users()`.
    """
    unique_users = []
    for consumers in self.get_target_users().values():
        # An empty (or None) consumer list contributes nothing.
        for (consumer, _) in consumers or ():
            if consumer in unique_users:
                continue
            unique_users.append(consumer)
    return unique_users
707
+
708
def get_targets(self) -> [ScopedValue]:
    """
    Return the targets of this node (`_targets`).

    - For CallCell, CallPrimitive, CallMethod or Tree nodes, `targets` name the results of invoking the
      cell-op, primitive-op or function-call, corresponding to the targets of ast.Assign.
    - For Input nodes, `targets` should have a single element naming the function parameter.
    - For Python or Output nodes, `targets` are don't-care.

    Returns:
        A list of instances of ScopedValue.
    """
    current_targets = self._targets
    return current_targets
723
+
724
def set_targets(self, targets: [ScopedValue]):
    """
    Replace the targets of this node (`_targets`).

    Note:
        This is not a user-interface: it may only be called before the node is inserted into a symbol-tree,
        because targets are made unique on insertion.

        For node types CallCell, CallMethod, CallPrimitive, Tree, CallFunction, CellContainer and MathOps the
        new targets are also synchronized to the ast node via `_sync_assign_targets_to_ast`.

        For Input nodes `targets` should have a single element naming the function parameter; for Python or
        Output nodes `targets` are don't-care.

    Args:
        targets ([ScopedValue]): A list of instances of ScopedValue as new targets.
    """
    self._targets = targets
    ast_synced_types = (
        NodeType.CallCell,
        NodeType.CallMethod,
        NodeType.CallPrimitive,
        NodeType.Tree,
        NodeType.CallFunction,
        NodeType.CellContainer,
        NodeType.MathOps,
    )
    if self._node_type not in ast_synced_types:
        return
    self._sync_assign_targets_to_ast()
751
+
752
def get_func_name(self) -> ScopedValue:
    """
    Return the function name of this node (`_func_name`).

    Returns:
        An instance of ScopedValue.
    """
    current_func_name = self._func_name
    return current_func_name
760
+
761
def set_func_name(self, func_name: ScopedValue):
    """
    Replace the function name of this node (`_func_name`).

    Note:
        For CallCell and CallPrimitive nodes the new name is also synchronized to the ast node via
        `_sync_assign_func_name_to_ast`.

    Args:
        func_name (ScopedValue): An instance of ScopedValue as the new function name.
    """
    self._func_name = func_name
    needs_ast_sync = self._node_type in (NodeType.CallCell, NodeType.CallPrimitive)
    if needs_ast_sync:
        self._sync_assign_func_name_to_ast()
774
+
775
def get_name(self) -> str:
    """
    Return the name of this node (`_name`).

    Returns:
        A str.
    """
    node_name = self._name
    return node_name
783
+
784
def set_name(self, name: str):
    """
    Replace the name of this node (`_name`).

    Args:
        name (str): The new name.
    """
    setattr(self, "_name", name)
792
+
793
def get_node_type(self) -> NodeType:
    """
    Return the node type of this node (`_node_type`).

    Returns:
        A NodeType.
    """
    kind = self._node_type
    return kind
801
+
802
def get_instance_type(self) -> type:
    """
    Return the type associated with the instance held by this node.

    - If the instance is a `LocalPrim`, its `prim_obj` is returned.
    - If the instance is a plain Python function, the function itself is returned.
    - Otherwise `type(instance)` is returned (NoneType when the node holds no instance).

    Per node type: CallCell gives the type of the cell-op, CallPrimitive the type of the primitive-op, Tree the
    type of the network-cell; Python, Input, Output or CallMethod nodes should give NoneType.

    Returns:
        A type (or a function object, see above).
    """
    instance = self._instance
    if isinstance(instance, LocalPrim):
        return instance.prim_obj
    return instance if inspect.isfunction(instance) else type(instance)
819
+
820
def get_instance(self):
    """
    Return the instance held by this node (`_instance`).

    - For CallCell nodes, an instance of Cell.
    - For CallPrimitive nodes, an instance of primitive.
    - For Tree nodes, an instance of network-cell.
    - For Python, Input, Output or CallMethod nodes, the instance should be None.

    Returns:
        An object.
    """
    held_instance = self._instance
    return held_instance
833
+
834
def set_arg_by_node(self, arg_idx: int, node: 'Node', out_idx: Optional[int] = None):
    """
    Set a positional argument of this node to one of the targets (outputs) of another node.

    Note that when `_normalized_args` is updated, the corresponding ast node is updated as well (through
    `_sync_arg`).

    Args:
        arg_idx (int): Index of the argument of this node to be replaced; must be in `[0, _args_num)`.
        node (Node): The node whose output becomes the new argument.
        out_idx ([int, optional]): Index of the output (target) of `node` to use. Default: ``None`` , which
            selects the only output of `node` and requires `node` to have exactly one output.

    Raises:
        ValueError: If `node` does not have exactly one output while `out_idx` is None.

    Note:
        Out-of-range `arg_idx` / `out_idx` and a wrongly typed `node` are rejected by the `Validator` helpers;
        the exception type is whatever those helpers raise.
    """
    Validator.check_value_type("node", node, [Node], "Node")
    Validator.check_int_range(arg_idx, 0, self._args_num, Validator.INC_LEFT, "arg_idx")
    provider_targets = node.get_targets()
    if out_idx is None:
        if len(provider_targets) != 1:
            raise ValueError("node should has one output when out_idx is not provided")
        out_idx = 0
    # Report range failures against the parameter actually being checked ("out_idx", not "arg_idx").
    Validator.check_int_range(out_idx, 0, len(provider_targets), Validator.INC_LEFT, "out_idx")
    self._normalized_args[self._normalized_args_keys[arg_idx]] = provider_targets[out_idx]
    self._sync_arg()
859
+
860
def set_arg(self, arg: Union[ScopedValue, str], index: int) -> (ScopedValue, ScopedValue):
    """
    Replace one positional argument of this node.

    Note that when `_normalized_args` is updated, the corresponding ast node is updated as well (through
    `_sync_arg`).

    Args:
        arg (Union[ScopedValue, str]): The new argument; a str is converted with
            `ScopedValue.create_naming_value`.
        index (int): Index of the argument to be replaced; must be in `[0, _args_num)`.

    Returns:
        A tuple `(new_arg, old_arg)`, where `new_arg` is the `ScopedValue` actually stored.
    """
    Validator.check_int_range(index, 0, self._args_num, Validator.INC_LEFT, "index")
    Validator.check_value_type("arg", arg, [ScopedValue, str], "Node")
    new_arg = ScopedValue.create_naming_value(arg) if isinstance(arg, str) else arg
    slot_key = self._normalized_args_keys[index]
    previous_arg = self._normalized_args.get(slot_key)
    self._normalized_args[slot_key] = new_arg
    self._sync_arg()
    return new_arg, previous_arg
880
+
881
def set_args(self, args: [ScopedValue]):
    """
    Replace the leading positional arguments of this node.

    The i-th element of `args` replaces the i-th positional argument. Note that when `_normalized_args` is
    updated, the corresponding ast node is updated as well (through `_sync_arg`).

    Args:
        args (list[ScopedValue]): New arguments; at most `_args_num` elements.

    Raises:
        TypeError: Element of new argument is not an instance of ScopedValue.
    """
    # The length may equal `_args_num` (replace every positional argument), so the upper bound is inclusive.
    # The previous half-open check `[0, _args_num)` rejected a full replacement.
    Validator.check_int_range(len(args), 0, self._args_num, Validator.INC_BOTH, "Length of args")
    Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
    for arg_index, arg in enumerate(args):
        if not isinstance(arg, ScopedValue):
            raise TypeError("arg should be ScopedValue, got: ", type(arg))
        self._normalized_args[self._normalized_args_keys[arg_index]] = arg
    self._sync_arg()
899
+
900
def set_kwargs(self, kwargs: {str: ScopedValue}):
    """
    Replace existing keyword arguments of this node.

    Every key of `kwargs` must already be a known argument key of this node. Note that when `_normalized_args`
    is updated, the corresponding ast node is updated as well (through `_sync_arg`).

    Args:
        kwargs (dict{str: ScopedValue}): New arguments; at most `_kwargs_num` entries.

    Raises:
        TypeError: Value of new argument is not an instance of ScopedValue.
        RuntimeError: A key of `kwargs` is not an existing argument key of this node.
    """
    # The size may equal `_kwargs_num` (replace every keyword argument), so the upper bound is inclusive.
    # The previous half-open check `[0, _kwargs_num)` rejected a full replacement.
    Validator.check_int_range(len(kwargs), 0, self._kwargs_num, Validator.INC_BOTH, "Length of kwargs")
    Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
    for key, arg in kwargs.items():
        if key not in self._normalized_args.keys() or key not in self._normalized_args_keys:
            raise RuntimeError("Input key is not exist, ", key)
        if not isinstance(arg, ScopedValue):
            raise TypeError("arg should be ScopedValue, got: ", type(arg))
        self._normalized_args[key] = arg
    self._sync_arg()
921
+
922
def set_kwarg(self, key: str, arg: ScopedValue):
    """
    Replace one existing keyword argument of this node.

    Note that when `_normalized_args` is updated, the corresponding ast node is updated as well (through
    `_sync_arg`).

    Args:
        key (str): Key of the keyword argument to be replaced.
        arg (ScopedValue): The new argument.

    Raises:
        RuntimeError: If `key` is not among the keyword keys (the keys after the first `_args_num` ones) or is
            missing from `_normalized_args`.
    """
    keyword_keys = self._normalized_args_keys[self._args_num:]
    is_existing_kwarg = key in keyword_keys and key in self._normalized_args
    if not is_existing_kwarg:
        raise RuntimeError("Input key is not exist, ", key)
    self._normalized_args[key] = arg
    self._sync_arg()
938
+
939
def get_args(self):
    """
    Return the positional arguments of this node.

    - For CallCell, CallPrimitive or Tree nodes, these correspond to the args of ast.Call, i.e. the arguments
      used to invoke the cell-op's forward method or the primitive-op's `call()` method.
    - For Input nodes, they represent the default value of the function argument.
    - For Output nodes, they represent the return values.
    - For Python nodes, they are don't-care.

    Returns:
        A list with the values stored under the first `_args_num` normalized keys (None for a missing key).
    """
    return [self._normalized_args.get(self._normalized_args_keys[position])
            for position in range(self._args_num)]
956
+
957
def get_kwargs(self):
    """
    Return the keyword arguments of this node.

    - For CallCell, CallPrimitive or Tree nodes, these correspond to the kwargs of ast.Call, i.e. the keyword
      arguments used to invoke the cell-op's forward method or the primitive-op's `call()` method.
    - For Python, Input or Output nodes, they are don't-care.

    Returns:
        A dict of str to instance of ScopedValue, built from the `_kwargs_num` normalized keys that follow the
        first `_args_num` ones (None for a missing key).
    """
    first_kwarg = self._args_num
    end_kwarg = first_kwarg + self._kwargs_num
    keyword_keys = [self._normalized_args_keys[position] for position in range(first_kwarg, end_kwarg)]
    return {key: self._normalized_args.get(key) for key in keyword_keys}
974
+
975
def get_normalized_args(self) -> {str: ScopedValue}:
    """
    Return the normalized arguments of this node.

    Normalized arguments merge positional and keyword arguments into one mapping, using the parameter name as
    the key of each argument.

    Returns:
        A new dict of str to instance of ScopedValue, ordered like `_normalized_args_keys` (None for a key
        missing from `_normalized_args`).
    """
    return {key: self._normalized_args.get(key) for key in self._normalized_args_keys}
988
+
989
def set_normalized_args(self, args: {str, ScopedValue}):
    """
    Overwrite the normalized arguments of this node.

    Normalized arguments merge positional and keyword arguments into one mapping, using the parameter name as
    the key of each argument. Only the number of entries is checked; the entries are then written into
    `_normalized_args` and synchronized through `_sync_arg`.

    Args:
        args ({str, ScopedValue}): A dict of str to instance of ScopedValue as the new normalized args.

    Raises:
        RuntimeError: If the number of entries differs from the number of normalized keys.
    """
    given_count = len(args.values())
    expected_count = len(self._normalized_args_keys)
    if given_count != expected_count:
        raise RuntimeError("Length of args.values() should be equal to length of _normalized_args_keys, ",
                           given_count, " vs ", expected_count)
    self._normalized_args.update(args.items())
    self._sync_arg()
1004
+
1005
def set_attribute(self, key: str, value):
    """
    Store one attribute on this node.

    Args:
        key (str): Key of the attribute.
        value (object): Value of the attribute.
    """
    self._attribute.update({key: value})
1014
+
1015
def set_attributes(self, attributes):
    """
    Replace the whole attribute dict of this node.

    The given object is stored as-is (not copied).

    Args:
        attributes (dict): The new attributes.
    """
    setattr(self, "_attribute", attributes)
1023
+
1024
def get_attributes(self):
    """
    Return all attributes of this node.

    Returns:
        The attribute dict itself (not a copy), mapping str to object.
    """
    all_attributes = self._attribute
    return all_attributes
1032
+
1033
def get_attribute(self, key: str):
    """
    Look up one attribute of this node.

    Args:
        key (str): Key of the wanted attribute.

    Returns:
        The attribute value, or None when `key` is not present.
    """
    found = self._attribute.get(key)
    return found
1044
+
1045
def get_arg_providers(self) -> dict:
    """
    Return the argument providers of this node (`_arg_providers`).

    Return:
        dict, key is an int giving the index of an argument, value is a tuple holding the node and the index
        of that node's target which provides the argument.
    """
    providers = self._arg_providers
    return providers
1054
+
1055
def set_arg_providers(self, index: int, provider: tuple):
    """
    Record the provider of one argument in `_arg_providers`.

    Args:
        index (int): Index of the argument whose provider is set.
        provider (tuple): A tuple holding the node and the index of that node's target which provides the
            argument.
    """
    self._arg_providers.update({index: provider})
1064
+
1065
def get_target_users(self, index=-1) -> Union[dict, list]:
    """
    Return the users of this node's targets (`_target_users`).

    Args:
        index (int): Index of the target whose users are wanted. Default: ``-1`` , meaning the users of all
            targets are returned.

    Return:
        Union[dict, list]. When `index` is -1, the whole dict is returned: key is the index of a target, value
        is the list of users of that target. Otherwise the list of users of the given target is returned; each
        element is a tuple holding the user node and the index of that node's argument which uses the target.
        An empty list is created and stored first when the target has no entry yet.
    """
    if index == -1:
        return self._target_users
    return self._target_users.setdefault(index, [])
1084
+
1085
def append_target_users(self, index: int, provider: tuple):
    """
    Append one user to the user list of a target in `_target_users`.

    Args:
        index (int): Index of the target that gains a user; its list is created when missing.
        provider (tuple): A tuple holding the user node and the index of that node's argument which uses the
            target.
    """
    self._target_users.setdefault(index, []).append(provider)
1097
+
1098
def update_ast_node(self) -> ast.AST:
    """
    Rebuild the node's ast from its current targets, func_name, args and kwargs.

    The ast created by AstModifier.create_call_assign is stored through set_ast and also returned.
    """
    build_call_assign = AstModifier.create_call_assign
    rebuilt_ast = build_call_assign(self.get_targets(), self.get_func_name(), self.get_args(), self.get_kwargs())
    self.set_ast(rebuilt_ast)
    return rebuilt_ast
1104
+
1105
def get_source_code(self) -> str:
    """Unparse the ast of the node into source code and return it without surrounding whitespace."""
    unparsed_code = astunparse.unparse(self._ast_node)
    return unparsed_code.strip()
1108
+
1109
def append_kwarg(self, kwarg: Dict[str, ScopedValue]):
    """
    Append new keyword args to the node and to the ast.Call held by the node.

    Args:
        kwarg (Dict[str, ScopedValue]): The new keyword args, keyed by keyword name.

    Raises:
        TypeError: If the type of the node is neither Tree nor CallFunction.
    """
    node_type = self.get_node_type()
    if node_type not in (NodeType.Tree, NodeType.CallFunction):
        raise TypeError(f"For append_new_kwarg, the type of node can only be one of [Tree, CallFunction], "
                        f"but got {node_type}")
    Validator.check_element_type_of_dict("kwarg", kwarg, [str], [ScopedValue], "append_new_kwarg")
    for keyword_name, scoped_value in kwarg.items():
        # Book-keeping of the normalized args: value, key order and number of kwargs.
        self._normalized_args[keyword_name] = scoped_value
        self._normalized_args_keys.append(keyword_name)
        self._kwargs_num += 1
        # Mirror the new keyword on the ast; the ast of the node is treated as `targets = call(...)`.
        call_ast: ast.Call = self._ast_node.value
        keyword_ast = ast.keyword(arg=keyword_name, value=AstModifier.get_ast_by_value(scoped_value, None))
        call_ast.keywords.append(keyword_ast)
1131
+
1132
def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict:
    """
    Merge args and kwargs into one dict of normalized args.

    When the node holds an instance whose class defines `construct`, arg names are taken from the
    signature of that `construct`. Otherwise positional args are named `arg_0`, `arg_1`, ... (skipping
    names already taken by kwargs or by earlier args) and kwargs keep their own names. In that fallback
    case every resulting key is also appended to self._normalized_args_keys here; in the `construct`
    case self._normalized_args_keys is handed to Node._map_args_names.

    Args:
        args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
        kwargs (dict{str: ScopedValue}): A dict whose values are instances of ScopedValue. See detail in
            docstring of Node class.

    Raises:
        RuntimeError: Input args are invalid (not raised in this method itself; presumably raised by
            Node._map_args_names).
        RuntimeError: Arg name already exist in kwargs (same remark).

    Returns:
        The normalized args.
    """
    positional = args if args else []
    keywords = kwargs if kwargs else {}
    merged: dict = {}
    if (positional or keywords) and self._instance and hasattr(type(self._instance), "construct"):
        construct_params = inspect.signature(type(self._instance).construct).parameters
        construct_names = Node._get_construct_arg_names(construct_params)
        Node._map_args_names(construct_names, positional, keywords, self._normalized_args_keys, merged)
        return merged

    logger.debug("fail to get arg name from op, using arg_xx for args' name")
    suffix = 0
    for value in positional:
        candidate = f"arg_{suffix}"
        # Advance the suffix until the generated name collides with nothing.
        while candidate in keywords or candidate in merged:
            suffix += 1
            candidate = f"arg_{suffix}"
        merged[candidate] = value
        self._normalized_args_keys.append(candidate)
    for keyword_name, value in keywords.items():
        merged[keyword_name] = value
        self._normalized_args_keys.append(keyword_name)
    return merged
1171
+
1172
+ # Synchronize rewrite node args to ast node
1173
+ def _sync_assign_func_name_to_ast(self):
1174
+ """Sync func_name of ast.Call of ast.Assign from self._name when NodeType is CallCell or CallPrimitive."""
1175
+ if self._ast_node is None:
1176
+ return
1177
+ assign_ast = self._ast_node
1178
+ if not isinstance(assign_ast, ast.Assign):
1179
+ raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
1180
+ call_ast = assign_ast.value
1181
+ if not isinstance(call_ast, ast.Call):
1182
+ raise TypeError("call_ast should be ast.Call, got: ", type(call_ast))
1183
+ if self._func_name.type == ValueType.UnsupportedValue:
1184
+ return
1185
+ func_ast = call_ast.func
1186
+ if not self._func_name.scope:
1187
+ if isinstance(func_ast, ast.Name):
1188
+ func_ast.id = self._func_name.value
1189
+ else:
1190
+ call_ast.func = ast.Name(self._func_name.value, ast.Store())
1191
+ else:
1192
+ if isinstance(func_ast, ast.Attribute):
1193
+ if not isinstance(func_ast.value, ast.Name):
1194
+ func_ast.value = ast.Name(self._func_name.scope, ast.Load())
1195
+ else:
1196
+ func_ast.value.id = self._func_name.scope
1197
+ func_ast.attr = self._func_name.value
1198
+ else:
1199
+ call_ast.func = ast.Attribute(ast.Name(self._func_name.scope, ast.Load()),
1200
+ self._func_name.value, ast.Store())
1201
+ ast.fix_missing_locations(assign_ast)
1202
+
1203
+ def _sync_assign_targets_to_ast(self):
1204
+ """Sync targets of ast.Assign from self._targets when NodeType is CallCell, CallPrimitive or CallMethod."""
1205
+ if self._ast_node is None:
1206
+ return
1207
+ assign_ast = self._ast_node
1208
+ if not isinstance(assign_ast, ast.Assign):
1209
+ raise TypeError(error_str(f"assign_ast should be ast.Assign, but got: {type(assign_ast)}",
1210
+ father_node=assign_ast))
1211
+ # update targets
1212
+ target_ast_elems = AstConverter.get_ast_target_elems(assign_ast.targets[0])
1213
+ if len(self._targets) != len(target_ast_elems):
1214
+ raise ValueError(error_str(f"The number of targets should be {len(target_ast_elems)}, "
1215
+ f"but got {len(self._targets)}", father_node=assign_ast))
1216
+ for i, target_ast in enumerate(target_ast_elems):
1217
+ target_ast_elems[i] = AstModifier.get_ast_by_value(self._targets[i], target_ast)
1218
+
1219
+ def _sync_call_args_to_ast(self):
1220
+ """Sync args of ast.Call from self._normalized_args."""
1221
+ if self._ast_node is None:
1222
+ return
1223
+ assign_ast = self._ast_node
1224
+ if not isinstance(assign_ast, ast.Assign):
1225
+ raise TypeError(f"When synchronizing args for '{self._name}'({self._node_type}), _ast_node should be "
1226
+ f"ast.Assign, but got: {type(assign_ast)}")
1227
+ assign_value = assign_ast.value
1228
+ if not isinstance(assign_value, ast.Call):
1229
+ if isinstance(assign_value, ast.Attribute) and self._node_type in (NodeType.CellContainer,
1230
+ NodeType.CallCell):
1231
+ # CellContainers in control flow may be flatten to ast.Attribute: blocks_var = self.blocks
1232
+ # In this case, no args exist in node, so we don't need to sync.
1233
+ # CellContainers may be type of CallCell when share one implementation
1234
+ return
1235
+ raise TypeError(f"When synchronizing args for '{self._name}'({self._node_type}), _ast_node.value should "
1236
+ f"be ast.Call, but got: {type(assign_value)}")
1237
+ keywords_ast = assign_value.keywords
1238
+ args_ast = assign_value.args
1239
+ if len(self._normalized_args_keys) != (len(keywords_ast) + len(args_ast)):
1240
+ raise ValueError("ast keywords plus args len is not equal to self._normalized_args value")
1241
+ for arg_index in range(self._args_num):
1242
+ arg_ast = args_ast[arg_index]
1243
+ args_ast[arg_index] = \
1244
+ AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[arg_index]), arg_ast)
1245
+
1246
+ # the order of kwargs may not the same as that in keywords_ast
1247
+ keyword_map_index = {}
1248
+ for index, keyword_ast in enumerate(keywords_ast):
1249
+ keyword_map_index[keyword_ast.arg] = index
1250
+ for keyword_index in range(self._kwargs_num):
1251
+ key = self._normalized_args_keys[keyword_index + self._args_num]
1252
+ keywords_ast[keyword_map_index.get(key)].value = \
1253
+ AstModifier.get_ast_by_value(self._normalized_args.get(key),
1254
+ keywords_ast[keyword_map_index.get(key)].value)
1255
+
1256
+ def _sync_call_method_args_to_ast(self):
1257
+ """
1258
+ Sync args to value of ast.Assign from self._normalized_args when NodeType is CallMethod.
1259
+ For node with type of CallMethod, the value of ast.Assign is one of:
1260
+ | func_name | data_type | value of ast.Assign |
1261
+ |:---------------|:------------|:------------------------|
1262
+ | 'pass_through' | constants | ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str |
1263
+ | 'pass_through' | variables | ast.Name, ast.Attribute |
1264
+ | 'tuple' | tuple | ast.Tuple |
1265
+ | 'list' | list | ast.List |
1266
+ | 'dict' | dict | ast.Dict |
1267
+ """
1268
+ if self._ast_node is None:
1269
+ return
1270
+ assign_ast = self._ast_node
1271
+ if not isinstance(assign_ast, ast.Assign):
1272
+ raise TypeError(f"For node '{self.get_name()}', assign_ast should be ast.Assign, got: ", type(assign_ast))
1273
+ assign_value = assign_ast.value
1274
+ if self._func_name.value == "pass_through":
1275
+ # update constants/variables
1276
+ assign_ast.value = \
1277
+ AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]), assign_value)
1278
+ elif self._func_name.value in ("tuple", "list", "dict"):
1279
+ # update tuple/list/dict
1280
+ ast_elts = assign_value.values if isinstance(assign_value, ast.Dict) else assign_value.elts
1281
+ if len(self._normalized_args_keys) != len(ast_elts):
1282
+ raise ValueError(f"For node '{self.get_name()}', size of self._normalized_args_keys"
1283
+ f"({len(self._normalized_args_keys)}) should be equal to size of elements of "
1284
+ f"ast_elts({len(ast_elts)})")
1285
+ for index, elt in enumerate(ast_elts):
1286
+ scoped_value: ScopedValue = self._normalized_args.get(self._normalized_args_keys[index])
1287
+ ast_elts[index] = AstModifier.get_ast_by_value(scoped_value, elt)
1288
+ else:
1289
+ raise TypeError(f"For node '{self.get_name()}', only support (pass_through, tuple or dict method) as "
1290
+ f"call_method, but got {self._func_name.value}")
1291
+
1292
+ def _sync_return_node_to_ast(self):
1293
+ """
1294
+ Sync args to value of ast.Return from self._normalized_args when NodeType is Output.
1295
+
1296
+ For node with type of CallMethod, the value of ast.Assign is one of:
1297
+ (ast.Name, ast.Attribute)
1298
+ """
1299
+ if self._ast_node is None:
1300
+ return
1301
+ return_ast = self._ast_node
1302
+ if not isinstance(return_ast, ast.Return):
1303
+ raise TypeError(f"For node '{self.get_name()}', return_ast should be ast.Return, got: {type(return_ast)}")
1304
+ return_value_ast = return_ast.value
1305
+ return_ast.value = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]),
1306
+ return_value_ast)
1307
+
1308
+ def _sync_mathops_node_args_to_ast(self):
1309
+ """
1310
+ Sync values from self._normalized_args to the ast node for mathematical operations.
1311
+ """
1312
+ if self._ast_node is None:
1313
+ return
1314
+ if not isinstance(self._ast_node, ast.Assign):
1315
+ raise TypeError(f"type of node should be ast.Assign, but got {type(self._ast_node)}")
1316
+ mathops_node = self._ast_node.value
1317
+ if isinstance(mathops_node, ast.BinOp):
1318
+ left = mathops_node.left
1319
+ right = mathops_node.right
1320
+ mathops_node.left = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]),
1321
+ left)
1322
+ mathops_node.right = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[1]),
1323
+ right)
1324
+ elif isinstance(mathops_node, ast.UnaryOp):
1325
+ operand = mathops_node.operand
1326
+ mathops_node.operand = \
1327
+ AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]), operand)
1328
+ elif isinstance(mathops_node, ast.BoolOp):
1329
+ values = mathops_node.values
1330
+ for arg_index in range(self._args_num):
1331
+ arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
1332
+ values[arg_index] = AstModifier.get_ast_by_value(arg_value, values[arg_index])
1333
+ elif isinstance(mathops_node, ast.Compare):
1334
+ left = mathops_node.left
1335
+ mathops_node.left = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]),
1336
+ left)
1337
+ comparators = mathops_node.comparators
1338
+ for arg_index in range(1, self._args_num):
1339
+ arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
1340
+ comparators[arg_index - 1] = AstModifier.get_ast_by_value(arg_value, comparators[arg_index - 1])
1341
+ else:
1342
+ raise TypeError("The type of 'mathops_node' must be one of (ast.BinOp, ast.UnaryOp, "
1343
+ "ast.BoolOp, ast.Compare), but got ", type(mathops_node))
1344
+
1345
+ def _sync_control_flow_args_to_ast(self):
1346
+ """
1347
+ Sync values from self._normalized_args to the ast node of control flow.
1348
+ """
1349
+ if self._ast_node is None:
1350
+ return
1351
+ normalized_args_num = len(self._normalized_args_keys)
1352
+ if normalized_args_num == 0:
1353
+ return
1354
+ if normalized_args_num > 1:
1355
+ raise ValueError("self._normalized_args_keys should have less than 1 elements")
1356
+ arg_value = self._normalized_args.get(self._normalized_args_keys[0])
1357
+ if isinstance(self._ast_node, (ast.If, ast.IfExp, ast.While)):
1358
+ self._ast_node.test = AstModifier.get_ast_by_value(arg_value, self._ast_node.test)
1359
+ elif isinstance(self._ast_node, ast.For):
1360
+ self._ast_node.iter = AstModifier.get_ast_by_value(arg_value, self._ast_node.iter)
1361
+ else:
1362
+ raise ValueError(f"For Control Flow, ast_node should be one of [ast.If, ast.IfExp, "
1363
+ f"ast.While, ast.For], but got {type(self._ast_node)}")
1364
+
1365
def _sync_arg(self):
    """
    Sync _normalized_args to corresponding ast node when updated.

    The sync routine is selected by self._node_type; node types not listed below are left untouched.
    """
    node_type = self._node_type
    call_like_types = (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree,
                       NodeType.CellContainer, NodeType.CallFunction)
    if node_type in call_like_types:
        self._sync_call_args_to_ast()
        return
    if node_type == NodeType.Output:
        self._sync_return_node_to_ast()
        return
    if node_type == NodeType.CallMethod:
        self._sync_call_method_args_to_ast()
        return
    if node_type == NodeType.MathOps:
        self._sync_mathops_node_args_to_ast()
        return
    if node_type == NodeType.ControlFlow:
        self._sync_control_flow_args_to_ast()
1378
+
1379
+
1380
+ # Child classes
1381
+ class TreeNode(Node):
1382
+ """Tree type Node who holds a handler of SymbolTree."""
1383
+
1384
def __init__(self, tree, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
             args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
    """
    Constructor of TreeNode. Rewrite recommends instantiating a TreeNode through the class method
    `create_tree_node` rather than invoking this constructor directly.

    Args:
        tree: An instance of SymbolTree represents a handler of sub-symbol-tree.
        ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
        targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
        func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
            A str is also accepted and is converted by ScopedValue.create_naming_value.
        args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
        kwargs (dict{str: ScopedValue}): A dict whose values are instances of ScopedValue. See detail in
            docstring of Node class.
        name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
            Name of node also used as field name in network class.
        instance: Object in network corresponding to this node.
    """
    func_value = ScopedValue.create_naming_value(func) if isinstance(func, str) else func
    super().__init__(NodeType.Tree, ast_node, targets, func_value, args, kwargs, name, instance)
    # Handler of the sub-symbol-tree this node stands for.
    self.symbol_tree = tree
1405
+
1406
@classmethod
def create_tree_node(cls, tree, ast_node: ast.AST, targets: Union[ScopedValue, str],
                     func_name: Union[ScopedValue, str], args: [ScopedValue], kwargs: {str: ScopedValue},
                     name: str = "", instance=None):
    """
    Class method of TreeNode. Instantiate an instance of node whose type is Tree. A Tree node represents an
    invoking to sub-network.

    Args:
        tree: An instance of SymbolTree represents a handler of sub-symbol-tree.
        ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
        targets (list[ScopedValue]): Targets of the node; they are passed through Node._handle_targets
            before the node is constructed. See detail in docstring of Node class.
        func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
            A str is also accepted and is converted by ScopedValue.create_naming_value.
        args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
        kwargs (dict{str: ScopedValue}): A dict whose values are instances of ScopedValue. See detail in
            docstring of Node class.
        name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
            Name of node also used as field name in network class.
        instance: Object in network corresponding to this node.

    Returns:
        An instance of `cls` built from the given arguments.
    """
    handled_targets = Node._handle_targets(targets)
    func_value = ScopedValue.create_naming_value(func_name) if isinstance(func_name, str) else func_name
    return cls(tree, ast_node, handled_targets, func_value, args, kwargs, name, instance)