mindspore-2.3.0-cp310-cp310-win_amd64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (1400)
  1. mindspore/.commit_id +1 -0
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +51 -0
  18. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +106 -0
  22. mindspore/_checkparam.py +1378 -0
  23. mindspore/_extends/__init__.py +23 -0
  24. mindspore/_extends/builtin_operations.py +224 -0
  25. mindspore/_extends/graph_kernel/__init__.py +17 -0
  26. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  27. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  28. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  29. mindspore/_extends/graph_kernel/model/model.py +553 -0
  30. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  31. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  32. mindspore/_extends/graph_kernel/splitter.py +140 -0
  33. mindspore/_extends/graph_kernel/utils.py +28 -0
  34. mindspore/_extends/parallel_compile/__init__.py +19 -0
  35. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  38. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  39. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  40. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  41. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  42. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  43. mindspore/_extends/parse/__init__.py +49 -0
  44. mindspore/_extends/parse/compile_config.py +258 -0
  45. mindspore/_extends/parse/namespace.py +136 -0
  46. mindspore/_extends/parse/parser.py +1446 -0
  47. mindspore/_extends/parse/resources.py +213 -0
  48. mindspore/_extends/parse/standard_method.py +4437 -0
  49. mindspore/_extends/parse/trope.py +97 -0
  50. mindspore/_extends/pijit/__init__.py +23 -0
  51. mindspore/_extends/pijit/pijit_func_white_list.py +343 -0
  52. mindspore/_extends/remote/__init__.py +19 -0
  53. mindspore/_extends/remote/kernel_build_server.py +199 -0
  54. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  57. mindspore/_extends/utils.py +68 -0
  58. mindspore/_install_custom.py +43 -0
  59. mindspore/_profiler.py +30 -0
  60. mindspore/amp.py +419 -0
  61. mindspore/atlprov.dll +0 -0
  62. mindspore/avcodec-59.dll +0 -0
  63. mindspore/avdevice-59.dll +0 -0
  64. mindspore/avfilter-8.dll +0 -0
  65. mindspore/avformat-59.dll +0 -0
  66. mindspore/avutil-57.dll +0 -0
  67. mindspore/boost/__init__.py +42 -0
  68. mindspore/boost/adasum.py +319 -0
  69. mindspore/boost/base.py +535 -0
  70. mindspore/boost/boost.py +400 -0
  71. mindspore/boost/boost_cell_wrapper.py +790 -0
  72. mindspore/boost/dim_reduce.py +323 -0
  73. mindspore/boost/grad_accumulation.py +79 -0
  74. mindspore/boost/grad_freeze.py +382 -0
  75. mindspore/boost/group_loss_scale_manager.py +166 -0
  76. mindspore/boost/less_batch_normalization.py +174 -0
  77. mindspore/c1.dll +0 -0
  78. mindspore/c1xx.dll +0 -0
  79. mindspore/c2.dll +0 -0
  80. mindspore/cfgpersist.dll +0 -0
  81. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  82. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  83. mindspore/common/__init__.py +84 -0
  84. mindspore/common/_auto_dynamic.py +68 -0
  85. mindspore/common/_decorator.py +50 -0
  86. mindspore/common/_jit_fallback_utils.py +110 -0
  87. mindspore/common/_monad.py +25 -0
  88. mindspore/common/_register_for_adapter.py +74 -0
  89. mindspore/common/_register_for_recompute.py +48 -0
  90. mindspore/common/_register_for_tensor.py +45 -0
  91. mindspore/common/_stub_tensor.py +210 -0
  92. mindspore/common/_utils.py +122 -0
  93. mindspore/common/api.py +2049 -0
  94. mindspore/common/auto_dynamic_shape.py +507 -0
  95. mindspore/common/dtype.py +422 -0
  96. mindspore/common/dump.py +131 -0
  97. mindspore/common/file_system.py +48 -0
  98. mindspore/common/generator.py +260 -0
  99. mindspore/common/hook_handle.py +155 -0
  100. mindspore/common/initializer.py +880 -0
  101. mindspore/common/jit_config.py +98 -0
  102. mindspore/common/lazy_inline.py +240 -0
  103. mindspore/common/mindir_util.py +111 -0
  104. mindspore/common/mutable.py +234 -0
  105. mindspore/common/no_inline.py +54 -0
  106. mindspore/common/np_dtype.py +25 -0
  107. mindspore/common/parameter.py +1048 -0
  108. mindspore/common/recompute.py +262 -0
  109. mindspore/common/seed.py +260 -0
  110. mindspore/common/sparse_tensor.py +1171 -0
  111. mindspore/common/symbol.py +122 -0
  112. mindspore/common/tensor.py +4859 -0
  113. mindspore/communication/__init__.py +37 -0
  114. mindspore/communication/_comm_helper.py +466 -0
  115. mindspore/communication/_hccl_management.py +297 -0
  116. mindspore/communication/comm_func.py +1140 -0
  117. mindspore/communication/management.py +673 -0
  118. mindspore/config/op_info.config +533 -0
  119. mindspore/context.py +1976 -0
  120. mindspore/d3dcompiler_47.dll +0 -0
  121. mindspore/dataset/__init__.py +90 -0
  122. mindspore/dataset/audio/__init__.py +61 -0
  123. mindspore/dataset/audio/transforms.py +3690 -0
  124. mindspore/dataset/audio/utils.py +386 -0
  125. mindspore/dataset/audio/validators.py +1172 -0
  126. mindspore/dataset/callback/__init__.py +20 -0
  127. mindspore/dataset/callback/ds_callback.py +368 -0
  128. mindspore/dataset/callback/validators.py +32 -0
  129. mindspore/dataset/core/__init__.py +13 -0
  130. mindspore/dataset/core/config.py +1088 -0
  131. mindspore/dataset/core/datatypes.py +101 -0
  132. mindspore/dataset/core/py_util_helpers.py +65 -0
  133. mindspore/dataset/core/validator_helpers.py +774 -0
  134. mindspore/dataset/debug/__init__.py +21 -0
  135. mindspore/dataset/debug/debug_hook.py +97 -0
  136. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  137. mindspore/dataset/engine/__init__.py +124 -0
  138. mindspore/dataset/engine/cache_admin.py +47 -0
  139. mindspore/dataset/engine/cache_client.py +129 -0
  140. mindspore/dataset/engine/datasets.py +4554 -0
  141. mindspore/dataset/engine/datasets_audio.py +911 -0
  142. mindspore/dataset/engine/datasets_standard_format.py +493 -0
  143. mindspore/dataset/engine/datasets_text.py +2161 -0
  144. mindspore/dataset/engine/datasets_user_defined.py +1114 -0
  145. mindspore/dataset/engine/datasets_vision.py +4816 -0
  146. mindspore/dataset/engine/iterators.py +342 -0
  147. mindspore/dataset/engine/obs/__init__.py +23 -0
  148. mindspore/dataset/engine/obs/config_loader.py +68 -0
  149. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  150. mindspore/dataset/engine/obs/util.py +475 -0
  151. mindspore/dataset/engine/offload.py +596 -0
  152. mindspore/dataset/engine/queue.py +250 -0
  153. mindspore/dataset/engine/samplers.py +895 -0
  154. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  155. mindspore/dataset/engine/validators.py +2875 -0
  156. mindspore/dataset/text/__init__.py +54 -0
  157. mindspore/dataset/text/transforms.py +1703 -0
  158. mindspore/dataset/text/utils.py +715 -0
  159. mindspore/dataset/text/validators.py +642 -0
  160. mindspore/dataset/transforms/__init__.py +48 -0
  161. mindspore/dataset/transforms/c_transforms.py +638 -0
  162. mindspore/dataset/transforms/py_transforms.py +393 -0
  163. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  164. mindspore/dataset/transforms/transforms.py +1260 -0
  165. mindspore/dataset/transforms/validators.py +410 -0
  166. mindspore/dataset/utils/__init__.py +19 -0
  167. mindspore/dataset/utils/browse_dataset.py +190 -0
  168. mindspore/dataset/utils/line_reader.py +124 -0
  169. mindspore/dataset/vision/__init__.py +68 -0
  170. mindspore/dataset/vision/c_transforms.py +2641 -0
  171. mindspore/dataset/vision/py_transforms.py +2120 -0
  172. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  173. mindspore/dataset/vision/transforms.py +7295 -0
  174. mindspore/dataset/vision/utils.py +863 -0
  175. mindspore/dataset/vision/validators.py +1482 -0
  176. mindspore/default_config.py +2 -0
  177. mindspore/dnnl.dll +0 -0
  178. mindspore/dpcmi.dll +0 -0
  179. mindspore/experimental/__init__.py +20 -0
  180. mindspore/experimental/map_parameter.py +309 -0
  181. mindspore/experimental/optim/__init__.py +40 -0
  182. mindspore/experimental/optim/adadelta.py +161 -0
  183. mindspore/experimental/optim/adagrad.py +168 -0
  184. mindspore/experimental/optim/adam.py +193 -0
  185. mindspore/experimental/optim/adamax.py +170 -0
  186. mindspore/experimental/optim/adamw.py +205 -0
  187. mindspore/experimental/optim/asgd.py +153 -0
  188. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  189. mindspore/experimental/optim/nadam.py +157 -0
  190. mindspore/experimental/optim/optimizer.py +259 -0
  191. mindspore/experimental/optim/radam.py +194 -0
  192. mindspore/experimental/optim/rmsprop.py +154 -0
  193. mindspore/experimental/optim/rprop.py +164 -0
  194. mindspore/experimental/optim/sgd.py +156 -0
  195. mindspore/hal/__init__.py +40 -0
  196. mindspore/hal/_ascend.py +57 -0
  197. mindspore/hal/_base.py +57 -0
  198. mindspore/hal/_cpu.py +56 -0
  199. mindspore/hal/_gpu.py +57 -0
  200. mindspore/hal/device.py +356 -0
  201. mindspore/hal/event.py +179 -0
  202. mindspore/hal/memory.py +326 -0
  203. mindspore/hal/stream.py +339 -0
  204. mindspore/include/OWNERS +7 -0
  205. mindspore/include/api/allocator.h +97 -0
  206. mindspore/include/api/callback/callback.h +93 -0
  207. mindspore/include/api/callback/ckpt_saver.h +41 -0
  208. mindspore/include/api/callback/loss_monitor.h +33 -0
  209. mindspore/include/api/callback/lr_scheduler.h +51 -0
  210. mindspore/include/api/callback/time_monitor.h +34 -0
  211. mindspore/include/api/callback/train_accuracy.h +37 -0
  212. mindspore/include/api/cell.h +90 -0
  213. mindspore/include/api/cfg.h +82 -0
  214. mindspore/include/api/context.h +602 -0
  215. mindspore/include/api/data_type.h +47 -0
  216. mindspore/include/api/delegate.h +178 -0
  217. mindspore/include/api/delegate_api.h +75 -0
  218. mindspore/include/api/dual_abi_helper.h +208 -0
  219. mindspore/include/api/format.h +28 -0
  220. mindspore/include/api/graph.h +46 -0
  221. mindspore/include/api/kernel.h +58 -0
  222. mindspore/include/api/kernel_api.h +168 -0
  223. mindspore/include/api/metrics/accuracy.h +36 -0
  224. mindspore/include/api/metrics/metrics.h +41 -0
  225. mindspore/include/api/model.h +438 -0
  226. mindspore/include/api/model_group.h +79 -0
  227. mindspore/include/api/model_parallel_runner.h +168 -0
  228. mindspore/include/api/serialization.h +185 -0
  229. mindspore/include/api/status.h +192 -0
  230. mindspore/include/api/types.h +431 -0
  231. mindspore/include/api/visible.h +41 -0
  232. mindspore/include/c_api/context_c.h +179 -0
  233. mindspore/include/c_api/data_type_c.h +52 -0
  234. mindspore/include/c_api/format_c.h +46 -0
  235. mindspore/include/c_api/model_c.h +347 -0
  236. mindspore/include/c_api/ms/abstract.h +67 -0
  237. mindspore/include/c_api/ms/attribute.h +197 -0
  238. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  239. mindspore/include/c_api/ms/base/macros.h +32 -0
  240. mindspore/include/c_api/ms/base/status.h +33 -0
  241. mindspore/include/c_api/ms/base/types.h +283 -0
  242. mindspore/include/c_api/ms/context.h +102 -0
  243. mindspore/include/c_api/ms/graph.h +160 -0
  244. mindspore/include/c_api/ms/node.h +606 -0
  245. mindspore/include/c_api/ms/tensor.h +161 -0
  246. mindspore/include/c_api/ms/value.h +84 -0
  247. mindspore/include/c_api/status_c.h +79 -0
  248. mindspore/include/c_api/tensor_c.h +146 -0
  249. mindspore/include/c_api/types_c.h +67 -0
  250. mindspore/include/dataset/config.h +163 -0
  251. mindspore/include/dataset/constants.h +363 -0
  252. mindspore/include/dataset/execute.h +196 -0
  253. mindspore/include/dataset/text.h +1092 -0
  254. mindspore/include/dataset/transforms.h +638 -0
  255. mindspore/include/dataset/vision.h +2125 -0
  256. mindspore/include/dataset/vision_ascend.h +206 -0
  257. mindspore/include/dataset/vision_lite.h +625 -0
  258. mindspore/jpeg62.dll +0 -0
  259. mindspore/log.py +633 -0
  260. mindspore/mindrecord/__init__.py +43 -0
  261. mindspore/mindrecord/common/__init__.py +17 -0
  262. mindspore/mindrecord/common/constant.py +20 -0
  263. mindspore/mindrecord/common/enums.py +44 -0
  264. mindspore/mindrecord/common/exceptions.py +311 -0
  265. mindspore/mindrecord/config.py +809 -0
  266. mindspore/mindrecord/filereader.py +174 -0
  267. mindspore/mindrecord/filewriter.py +705 -0
  268. mindspore/mindrecord/mindpage.py +210 -0
  269. mindspore/mindrecord/shardheader.py +141 -0
  270. mindspore/mindrecord/shardindexgenerator.py +74 -0
  271. mindspore/mindrecord/shardreader.py +117 -0
  272. mindspore/mindrecord/shardsegment.py +128 -0
  273. mindspore/mindrecord/shardutils.py +185 -0
  274. mindspore/mindrecord/shardwriter.py +237 -0
  275. mindspore/mindrecord/tools/__init__.py +17 -0
  276. mindspore/mindrecord/tools/cifar10.py +140 -0
  277. mindspore/mindrecord/tools/cifar100.py +153 -0
  278. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  279. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  280. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  281. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  282. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  283. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  284. mindspore/mindspore_backend.dll +0 -0
  285. mindspore/mindspore_common.dll +0 -0
  286. mindspore/mindspore_core.dll +0 -0
  287. mindspore/mindspore_glog.dll +0 -0
  288. mindspore/mindspore_np_dtype.dll +0 -0
  289. mindspore/mindspore_shared_lib.dll +0 -0
  290. mindspore/mint/__init__.py +1137 -0
  291. mindspore/mint/linalg/__init__.py +22 -0
  292. mindspore/mint/nn/__init__.py +512 -0
  293. mindspore/mint/nn/functional.py +573 -0
  294. mindspore/mint/optim/__init__.py +24 -0
  295. mindspore/mint/optim/adamw.py +185 -0
  296. mindspore/msobj140.dll +0 -0
  297. mindspore/mspdb140.dll +0 -0
  298. mindspore/mspdbcore.dll +0 -0
  299. mindspore/mspdbst.dll +0 -0
  300. mindspore/mspft140.dll +0 -0
  301. mindspore/msvcdis140.dll +0 -0
  302. mindspore/msvcp140.dll +0 -0
  303. mindspore/msvcp140_1.dll +0 -0
  304. mindspore/msvcp140_2.dll +0 -0
  305. mindspore/msvcp140_atomic_wait.dll +0 -0
  306. mindspore/msvcp140_codecvt_ids.dll +0 -0
  307. mindspore/multiprocessing/__init__.py +72 -0
  308. mindspore/nn/__init__.py +48 -0
  309. mindspore/nn/cell.py +2605 -0
  310. mindspore/nn/dynamic_lr.py +482 -0
  311. mindspore/nn/extend/__init__.py +29 -0
  312. mindspore/nn/extend/basic.py +140 -0
  313. mindspore/nn/extend/embedding.py +143 -0
  314. mindspore/nn/extend/layer/__init__.py +27 -0
  315. mindspore/nn/extend/layer/normalization.py +109 -0
  316. mindspore/nn/extend/pooling.py +117 -0
  317. mindspore/nn/grad/__init__.py +21 -0
  318. mindspore/nn/grad/cell_grad.py +196 -0
  319. mindspore/nn/layer/__init__.py +63 -0
  320. mindspore/nn/layer/activation.py +1655 -0
  321. mindspore/nn/layer/basic.py +1519 -0
  322. mindspore/nn/layer/channel_shuffle.py +90 -0
  323. mindspore/nn/layer/combined.py +248 -0
  324. mindspore/nn/layer/container.py +734 -0
  325. mindspore/nn/layer/conv.py +1505 -0
  326. mindspore/nn/layer/dense.py +204 -0
  327. mindspore/nn/layer/embedding.py +751 -0
  328. mindspore/nn/layer/embedding_service.py +531 -0
  329. mindspore/nn/layer/embedding_service_layer.py +393 -0
  330. mindspore/nn/layer/image.py +661 -0
  331. mindspore/nn/layer/math.py +1069 -0
  332. mindspore/nn/layer/normalization.py +1177 -0
  333. mindspore/nn/layer/padding.py +894 -0
  334. mindspore/nn/layer/pooling.py +2148 -0
  335. mindspore/nn/layer/rnn_cells.py +388 -0
  336. mindspore/nn/layer/rnns.py +849 -0
  337. mindspore/nn/layer/thor_layer.py +963 -0
  338. mindspore/nn/layer/timedistributed.py +155 -0
  339. mindspore/nn/layer/transformer.py +823 -0
  340. mindspore/nn/learning_rate_schedule.py +512 -0
  341. mindspore/nn/loss/__init__.py +36 -0
  342. mindspore/nn/loss/loss.py +2846 -0
  343. mindspore/nn/metrics.py +53 -0
  344. mindspore/nn/optim/__init__.py +44 -0
  345. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  346. mindspore/nn/optim/ada_grad.py +217 -0
  347. mindspore/nn/optim/adadelta.py +206 -0
  348. mindspore/nn/optim/adafactor.py +448 -0
  349. mindspore/nn/optim/adam.py +1297 -0
  350. mindspore/nn/optim/adamax.py +220 -0
  351. mindspore/nn/optim/adasum.py +548 -0
  352. mindspore/nn/optim/asgd.py +216 -0
  353. mindspore/nn/optim/ftrl.py +401 -0
  354. mindspore/nn/optim/lamb.py +296 -0
  355. mindspore/nn/optim/lars.py +202 -0
  356. mindspore/nn/optim/lazyadam.py +533 -0
  357. mindspore/nn/optim/momentum.py +239 -0
  358. mindspore/nn/optim/optimizer.py +1034 -0
  359. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  360. mindspore/nn/optim/rmsprop.py +264 -0
  361. mindspore/nn/optim/rprop.py +251 -0
  362. mindspore/nn/optim/sgd.py +237 -0
  363. mindspore/nn/optim/thor.py +1310 -0
  364. mindspore/nn/probability/__init__.py +22 -0
  365. mindspore/nn/probability/bijector/__init__.py +35 -0
  366. mindspore/nn/probability/bijector/bijector.py +337 -0
  367. mindspore/nn/probability/bijector/exp.py +65 -0
  368. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  369. mindspore/nn/probability/bijector/invert.py +126 -0
  370. mindspore/nn/probability/bijector/power_transform.py +196 -0
  371. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  372. mindspore/nn/probability/bijector/softplus.py +189 -0
  373. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  374. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  375. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  376. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  377. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  378. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  379. mindspore/nn/probability/distribution/__init__.py +56 -0
  380. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  381. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  382. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  383. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  384. mindspore/nn/probability/distribution/beta.py +391 -0
  385. mindspore/nn/probability/distribution/categorical.py +435 -0
  386. mindspore/nn/probability/distribution/cauchy.py +383 -0
  387. mindspore/nn/probability/distribution/distribution.py +827 -0
  388. mindspore/nn/probability/distribution/exponential.py +350 -0
  389. mindspore/nn/probability/distribution/gamma.py +391 -0
  390. mindspore/nn/probability/distribution/geometric.py +335 -0
  391. mindspore/nn/probability/distribution/gumbel.py +257 -0
  392. mindspore/nn/probability/distribution/half_normal.py +133 -0
  393. mindspore/nn/probability/distribution/laplace.py +128 -0
  394. mindspore/nn/probability/distribution/log_normal.py +272 -0
  395. mindspore/nn/probability/distribution/logistic.py +379 -0
  396. mindspore/nn/probability/distribution/normal.py +336 -0
  397. mindspore/nn/probability/distribution/poisson.py +288 -0
  398. mindspore/nn/probability/distribution/student_t.py +149 -0
  399. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  400. mindspore/nn/probability/distribution/uniform.py +375 -0
  401. mindspore/nn/reinforcement/__init__.py +24 -0
  402. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  403. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  404. mindspore/nn/reinforcement/tensor_array.py +145 -0
  405. mindspore/nn/sparse/__init__.py +23 -0
  406. mindspore/nn/sparse/sparse.py +147 -0
  407. mindspore/nn/wrap/__init__.py +49 -0
  408. mindspore/nn/wrap/cell_wrapper.py +979 -0
  409. mindspore/nn/wrap/grad_reducer.py +608 -0
  410. mindspore/nn/wrap/loss_scale.py +680 -0
  411. mindspore/numpy/__init__.py +121 -0
  412. mindspore/numpy/array_creations.py +2734 -0
  413. mindspore/numpy/array_ops.py +2625 -0
  414. mindspore/numpy/dtypes.py +185 -0
  415. mindspore/numpy/fft.py +431 -0
  416. mindspore/numpy/logic_ops.py +935 -0
  417. mindspore/numpy/math_ops.py +5910 -0
  418. mindspore/numpy/utils.py +214 -0
  419. mindspore/numpy/utils_const.py +565 -0
  420. mindspore/opencv_core452.dll +0 -0
  421. mindspore/opencv_imgcodecs452.dll +0 -0
  422. mindspore/opencv_imgproc452.dll +0 -0
  423. mindspore/ops/__init__.py +54 -0
  424. mindspore/ops/_constants.py +30 -0
  425. mindspore/ops/_grad_experimental/__init__.py +31 -0
  426. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  427. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  428. mindspore/ops/_grad_experimental/grad_comm_ops.py +670 -0
  429. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  430. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  431. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  432. mindspore/ops/_grad_experimental/grad_math_ops.py +824 -0
  433. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  434. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  435. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  436. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  437. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  438. mindspore/ops/_op_impl/__init__.py +23 -0
  439. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  440. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  441. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  442. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  443. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  444. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  445. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  446. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  447. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  448. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  449. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  450. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  451. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  452. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  453. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  454. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  455. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  456. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  457. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  458. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  459. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  460. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  461. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  462. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  463. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  464. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  465. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  466. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  467. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  468. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  469. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  470. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  471. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  472. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  473. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  474. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  475. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  476. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  477. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  478. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  479. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  480. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  481. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  482. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  483. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  484. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  485. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  486. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  487. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  488. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  489. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  490. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  491. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  492. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  493. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  494. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  495. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  497. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  498. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  499. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  500. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  501. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  502. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  503. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  504. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  505. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  506. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  507. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  508. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  509. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  510. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  511. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  512. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  513. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  514. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  515. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  516. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  519. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  520. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  521. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  522. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  523. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  524. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  525. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  526. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  527. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  528. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  529. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  530. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  531. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  532. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  533. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  534. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  535. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  536. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  537. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  538. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  539. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  540. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  541. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  542. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  543. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  544. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  545. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  546. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  547. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  548. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  549. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  550. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  551. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  552. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  553. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  554. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  555. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  556. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  557. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  558. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  559. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  560. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  561. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  562. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  563. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  564. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  565. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  566. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  567. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  568. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  569. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  570. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  571. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  572. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  573. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  574. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  575. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  576. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  577. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  578. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  579. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  580. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  581. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  582. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  583. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  584. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  585. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  586. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  587. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  588. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  589. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  590. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  591. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  592. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  593. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  594. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  595. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  596. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  597. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  598. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  599. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  600. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  601. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  602. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  603. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  604. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  605. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  606. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  607. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  608. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  609. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  610. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  611. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  612. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  613. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  614. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  615. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  616. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  617. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  618. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  619. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  620. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  621. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  622. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  623. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  624. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  625. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  626. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  627. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  628. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  629. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  630. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  631. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  632. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  633. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  634. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  635. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  636. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  637. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  638. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  640. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  641. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  643. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  644. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  645. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  646. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  647. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  648. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  649. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  650. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  651. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  652. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  653. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  654. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  655. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  656. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  658. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  659. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  660. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  661. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  662. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  663. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  664. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  665. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  666. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  667. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  668. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  669. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  670. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  671. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  672. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  673. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  674. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  675. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  676. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  677. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  678. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  679. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  680. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  681. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  682. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  683. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  684. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  685. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  686. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  687. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  688. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  689. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  690. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  691. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  692. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  693. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  694. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  695. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  696. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  697. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  698. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  699. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  700. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  701. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  702. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  703. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  704. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  705. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  706. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  707. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  708. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  709. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  710. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  711. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  712. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  713. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  714. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  715. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  716. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  717. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  718. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  719. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  720. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  721. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  722. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  723. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  724. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  725. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  727. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  729. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  730. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  732. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  733. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  734. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  735. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  736. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  737. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  738. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  739. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  740. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  741. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  742. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  744. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  745. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  746. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  747. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  749. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  750. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  751. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  752. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  753. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  754. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  755. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  756. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  757. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  758. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  759. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  760. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  761. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  762. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  763. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  764. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  765. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  766. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  767. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  768. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  769. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  771. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  772. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  773. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  774. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  775. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  776. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  777. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  778. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  779. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  780. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  781. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  782. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  783. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  784. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  785. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  786. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  787. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  788. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  789. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  790. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  791. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  793. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  794. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  795. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  796. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  797. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  798. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  799. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  800. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  801. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  802. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  803. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  804. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  805. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  806. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  807. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  808. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  809. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  810. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  811. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  812. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  813. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  814. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  815. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  816. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  817. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  818. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  819. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  820. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  821. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  822. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  823. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  824. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  825. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  826. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  827. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  841. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  842. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  843. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  844. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  845. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  846. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  847. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  848. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  849. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  850. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  851. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  852. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  853. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  854. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  855. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  856. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  857. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  858. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  859. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  860. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  861. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  862. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  863. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  865. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  866. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  867. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  868. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  869. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  870. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  871. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  873. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  874. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  875. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  876. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  877. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  878. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  879. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  880. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  881. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  882. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  883. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  884. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  885. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  886. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  887. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  888. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  889. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  890. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  891. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  892. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  893. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  894. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  895. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  896. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  897. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  898. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  899. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  900. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  901. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  902. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  903. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  904. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  905. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  906. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  907. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  908. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  909. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  910. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  911. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  912. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  913. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  914. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  915. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  916. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  917. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  918. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  919. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  920. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  921. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  922. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  923. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  924. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  925. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  926. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  927. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  929. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  930. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  931. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  932. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  933. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  934. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  935. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  936. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  937. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  938. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  939. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  940. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  941. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  942. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  943. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  944. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  945. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  946. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  947. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  948. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  949. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  950. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  951. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  952. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  953. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  954. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  955. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  956. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  957. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  958. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  959. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  960. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  961. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  962. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  963. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  964. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  965. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  966. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  967. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  968. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  969. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  970. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  971. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  972. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  973. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  974. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  975. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  976. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  977. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  978. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  979. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  980. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  981. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  982. mindspore/ops/_op_impl/cpu/div.py +32 -0
  983. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  984. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  985. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  986. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  987. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  988. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  989. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  990. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  991. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  992. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  993. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  994. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  995. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  996. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  997. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  998. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  999. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  1000. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  1001. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  1002. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  1003. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  1004. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  1005. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  1006. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  1007. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  1009. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  1010. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  1011. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  1012. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  1013. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  1014. mindspore/ops/_op_impl/cpu/range.py +34 -0
  1015. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  1016. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  1017. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  1018. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  1019. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  1020. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  1021. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  1022. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1023. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1024. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1025. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1026. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1027. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1028. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1029. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1030. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1031. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1032. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1033. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1034. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1035. mindspore/ops/_primitive_cache.py +90 -0
  1036. mindspore/ops/_register_for_op.py +73 -0
  1037. mindspore/ops/_utils/__init__.py +20 -0
  1038. mindspore/ops/_utils/utils.py +147 -0
  1039. mindspore/ops/_vmap/__init__.py +25 -0
  1040. mindspore/ops/_vmap/vmap_array_ops.py +2151 -0
  1041. mindspore/ops/_vmap/vmap_base.py +533 -0
  1042. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1043. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1044. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1045. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1046. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1047. mindspore/ops/_vmap/vmap_math_ops.py +977 -0
  1048. mindspore/ops/_vmap/vmap_nn_ops.py +2209 -0
  1049. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1050. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1051. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1052. mindspore/ops/auto_generate/__init__.py +31 -0
  1053. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +231 -0
  1054. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +250 -0
  1055. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1056. mindspore/ops/auto_generate/gen_extend_func.py +980 -0
  1057. mindspore/ops/auto_generate/gen_ops_def.py +6443 -0
  1058. mindspore/ops/auto_generate/gen_ops_prim.py +13167 -0
  1059. mindspore/ops/auto_generate/pyboost_inner_prim.py +429 -0
  1060. mindspore/ops/composite/__init__.py +71 -0
  1061. mindspore/ops/composite/base.py +1281 -0
  1062. mindspore/ops/composite/env_ops.py +41 -0
  1063. mindspore/ops/composite/math_ops.py +125 -0
  1064. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1065. mindspore/ops/composite/multitype_ops/_compile_utils.py +1458 -0
  1066. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1067. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1068. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1069. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1070. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1071. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1072. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1073. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1074. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1075. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1076. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1077. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1078. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1079. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1080. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1081. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1082. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1083. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1084. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1085. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1086. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1087. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1088. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1089. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1090. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1091. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1092. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1093. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1094. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1095. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1096. mindspore/ops/deprecated.py +315 -0
  1097. mindspore/ops/extend/__init__.py +53 -0
  1098. mindspore/ops/extend/array_func.py +218 -0
  1099. mindspore/ops/extend/math_func.py +76 -0
  1100. mindspore/ops/extend/nn_func.py +308 -0
  1101. mindspore/ops/function/__init__.py +760 -0
  1102. mindspore/ops/function/array_func.py +6889 -0
  1103. mindspore/ops/function/clip_func.py +384 -0
  1104. mindspore/ops/function/debug_func.py +69 -0
  1105. mindspore/ops/function/fft_func.py +31 -0
  1106. mindspore/ops/function/grad/__init__.py +34 -0
  1107. mindspore/ops/function/grad/grad_func.py +1424 -0
  1108. mindspore/ops/function/image_func.py +292 -0
  1109. mindspore/ops/function/linalg_func.py +416 -0
  1110. mindspore/ops/function/math_func.py +11877 -0
  1111. mindspore/ops/function/nn_func.py +8175 -0
  1112. mindspore/ops/function/other_func.py +114 -0
  1113. mindspore/ops/function/parameter_func.py +134 -0
  1114. mindspore/ops/function/random_func.py +1539 -0
  1115. mindspore/ops/function/reshard_func.py +102 -0
  1116. mindspore/ops/function/sparse_func.py +884 -0
  1117. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1118. mindspore/ops/function/spectral_func.py +150 -0
  1119. mindspore/ops/function/vmap_func.py +116 -0
  1120. mindspore/ops/functional.py +454 -0
  1121. mindspore/ops/op_info_register.py +1572 -0
  1122. mindspore/ops/operations/__init__.py +717 -0
  1123. mindspore/ops/operations/_csr_ops.py +403 -0
  1124. mindspore/ops/operations/_custom_grad.py +181 -0
  1125. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1126. mindspore/ops/operations/_grad_ops.py +3052 -0
  1127. mindspore/ops/operations/_infer_ops.py +19 -0
  1128. mindspore/ops/operations/_inner_ops.py +2567 -0
  1129. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1130. mindspore/ops/operations/_ms_kernel.py +601 -0
  1131. mindspore/ops/operations/_ocr_ops.py +379 -0
  1132. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1133. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1134. mindspore/ops/operations/_quant_ops.py +1844 -0
  1135. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1136. mindspore/ops/operations/_scalar_ops.py +106 -0
  1137. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1138. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1139. mindspore/ops/operations/_tensor_array.py +359 -0
  1140. mindspore/ops/operations/_thor_ops.py +807 -0
  1141. mindspore/ops/operations/array_ops.py +6258 -0
  1142. mindspore/ops/operations/comm_ops.py +1996 -0
  1143. mindspore/ops/operations/control_ops.py +127 -0
  1144. mindspore/ops/operations/custom_ops.py +1065 -0
  1145. mindspore/ops/operations/debug_ops.py +646 -0
  1146. mindspore/ops/operations/image_ops.py +1041 -0
  1147. mindspore/ops/operations/inner_ops.py +697 -0
  1148. mindspore/ops/operations/linalg_ops.py +95 -0
  1149. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1150. mindspore/ops/operations/manually_defined/_inner.py +61 -0
  1151. mindspore/ops/operations/manually_defined/ops_def.py +2016 -0
  1152. mindspore/ops/operations/math_ops.py +5306 -0
  1153. mindspore/ops/operations/nn_ops.py +9669 -0
  1154. mindspore/ops/operations/other_ops.py +871 -0
  1155. mindspore/ops/operations/random_ops.py +1243 -0
  1156. mindspore/ops/operations/reshard_ops.py +53 -0
  1157. mindspore/ops/operations/rl_ops.py +288 -0
  1158. mindspore/ops/operations/sparse_ops.py +2753 -0
  1159. mindspore/ops/operations/spectral_ops.py +111 -0
  1160. mindspore/ops/primitive.py +1034 -0
  1161. mindspore/ops/signature.py +54 -0
  1162. mindspore/ops/silent_check.py +162 -0
  1163. mindspore/ops/vm_impl_registry.py +91 -0
  1164. mindspore/ops_generate/__init__.py +27 -0
  1165. mindspore/ops_generate/arg_dtype_cast.py +250 -0
  1166. mindspore/ops_generate/arg_handler.py +197 -0
  1167. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1168. mindspore/ops_generate/gen_ops.py +1084 -0
  1169. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1170. mindspore/ops_generate/gen_pyboost_func.py +968 -0
  1171. mindspore/ops_generate/gen_utils.py +209 -0
  1172. mindspore/ops_generate/op_proto.py +138 -0
  1173. mindspore/ops_generate/pyboost_utils.py +354 -0
  1174. mindspore/ops_generate/template.py +239 -0
  1175. mindspore/parallel/__init__.py +28 -0
  1176. mindspore/parallel/_auto_parallel_context.py +1466 -0
  1177. mindspore/parallel/_cell_wrapper.py +91 -0
  1178. mindspore/parallel/_cost_model_context.py +700 -0
  1179. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1180. mindspore/parallel/_offload_context.py +275 -0
  1181. mindspore/parallel/_parallel_serialization.py +533 -0
  1182. mindspore/parallel/_ps_context.py +242 -0
  1183. mindspore/parallel/_recovery_context.py +110 -0
  1184. mindspore/parallel/_tensor.py +660 -0
  1185. mindspore/parallel/_transformer/__init__.py +35 -0
  1186. mindspore/parallel/_transformer/layers.py +765 -0
  1187. mindspore/parallel/_transformer/loss.py +251 -0
  1188. mindspore/parallel/_transformer/moe.py +693 -0
  1189. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1190. mindspore/parallel/_transformer/transformer.py +3119 -0
  1191. mindspore/parallel/_utils.py +600 -0
  1192. mindspore/parallel/algo_parameter_config.py +400 -0
  1193. mindspore/parallel/checkpoint_transform.py +643 -0
  1194. mindspore/parallel/cluster/__init__.py +15 -0
  1195. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1196. mindspore/parallel/cluster/process_entity/_api.py +344 -0
  1197. mindspore/parallel/cluster/process_entity/_utils.py +126 -0
  1198. mindspore/parallel/cluster/run.py +136 -0
  1199. mindspore/parallel/mpi/__init__.py +14 -0
  1200. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1201. mindspore/parallel/parameter_broadcast.py +152 -0
  1202. mindspore/parallel/shard.py +350 -0
  1203. mindspore/perf_msvcbuildinsights.dll +0 -0
  1204. mindspore/pgodb140.dll +0 -0
  1205. mindspore/pgort140.dll +0 -0
  1206. mindspore/profiler/__init__.py +27 -0
  1207. mindspore/profiler/common/__init__.py +14 -0
  1208. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1209. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1210. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1211. mindspore/profiler/common/process_pool.py +41 -0
  1212. mindspore/profiler/common/singleton.py +28 -0
  1213. mindspore/profiler/common/struct_type.py +118 -0
  1214. mindspore/profiler/common/util.py +444 -0
  1215. mindspore/profiler/common/validator/__init__.py +14 -0
  1216. mindspore/profiler/common/validator/validate_path.py +84 -0
  1217. mindspore/profiler/envprofiling.py +256 -0
  1218. mindspore/profiler/parser/__init__.py +14 -0
  1219. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1220. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1221. mindspore/profiler/parser/ascend_analysis/constant.py +53 -0
  1222. mindspore/profiler/parser/ascend_analysis/file_manager.py +159 -0
  1223. mindspore/profiler/parser/ascend_analysis/function_event.py +161 -0
  1224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +131 -0
  1225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +85 -0
  1226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +57 -0
  1227. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
  1228. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1229. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +68 -0
  1230. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1231. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1232. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1233. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1234. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1235. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1236. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1237. mindspore/profiler/parser/ascend_msprof_exporter.py +281 -0
  1238. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1239. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1240. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1241. mindspore/profiler/parser/ascend_timeline_generator.py +543 -0
  1242. mindspore/profiler/parser/base_timeline_generator.py +489 -0
  1243. mindspore/profiler/parser/container.py +229 -0
  1244. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  1245. mindspore/profiler/parser/flops_parser.py +531 -0
  1246. mindspore/profiler/parser/framework_enum.py +111 -0
  1247. mindspore/profiler/parser/framework_parser.py +854 -0
  1248. mindspore/profiler/parser/framework_struct.py +61 -0
  1249. mindspore/profiler/parser/hccl_parser.py +573 -0
  1250. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1251. mindspore/profiler/parser/integrator.py +526 -0
  1252. mindspore/profiler/parser/memory_usage_parser.py +431 -0
  1253. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1254. mindspore/profiler/parser/minddata_parser.py +186 -0
  1255. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1256. mindspore/profiler/parser/msadvisor_analyzer.py +82 -0
  1257. mindspore/profiler/parser/msadvisor_parser.py +240 -0
  1258. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1259. mindspore/profiler/parser/optime_parser.py +250 -0
  1260. mindspore/profiler/parser/profiler_info.py +141 -0
  1261. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1262. mindspore/profiler/profiling.py +2054 -0
  1263. mindspore/rewrite/__init__.py +29 -0
  1264. mindspore/rewrite/api/__init__.py +17 -0
  1265. mindspore/rewrite/api/node.py +519 -0
  1266. mindspore/rewrite/api/node_type.py +53 -0
  1267. mindspore/rewrite/api/pattern_engine.py +490 -0
  1268. mindspore/rewrite/api/scoped_value.py +181 -0
  1269. mindspore/rewrite/api/symbol_tree.py +497 -0
  1270. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1271. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1272. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1273. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1274. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1275. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1276. mindspore/rewrite/common/__init__.py +19 -0
  1277. mindspore/rewrite/common/config.py +24 -0
  1278. mindspore/rewrite/common/error_log.py +39 -0
  1279. mindspore/rewrite/common/event.py +28 -0
  1280. mindspore/rewrite/common/namer.py +271 -0
  1281. mindspore/rewrite/common/namespace.py +118 -0
  1282. mindspore/rewrite/common/observable.py +44 -0
  1283. mindspore/rewrite/common/observer.py +54 -0
  1284. mindspore/rewrite/node/__init__.py +22 -0
  1285. mindspore/rewrite/node/call_function.py +95 -0
  1286. mindspore/rewrite/node/cell_container.py +139 -0
  1287. mindspore/rewrite/node/control_flow.py +113 -0
  1288. mindspore/rewrite/node/node.py +1428 -0
  1289. mindspore/rewrite/node/node_manager.py +283 -0
  1290. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1291. mindspore/rewrite/parsers/__init__.py +29 -0
  1292. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1293. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1294. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1295. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1296. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1297. mindspore/rewrite/parsers/container_parser.py +88 -0
  1298. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1299. mindspore/rewrite/parsers/for_parser.py +61 -0
  1300. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1301. mindspore/rewrite/parsers/if_parser.py +85 -0
  1302. mindspore/rewrite/parsers/module_parser.py +117 -0
  1303. mindspore/rewrite/parsers/parser.py +43 -0
  1304. mindspore/rewrite/parsers/parser_register.py +86 -0
  1305. mindspore/rewrite/parsers/return_parser.py +37 -0
  1306. mindspore/rewrite/parsers/while_parser.py +59 -0
  1307. mindspore/rewrite/sparsify/__init__.py +0 -0
  1308. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1309. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1310. mindspore/rewrite/sparsify/utils.py +179 -0
  1311. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1312. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1313. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1314. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1315. mindspore/run_check/__init__.py +20 -0
  1316. mindspore/run_check/_check_version.py +574 -0
  1317. mindspore/run_check/run_check.py +66 -0
  1318. mindspore/safeguard/__init__.py +18 -0
  1319. mindspore/safeguard/rewrite_obfuscation.py +531 -0
  1320. mindspore/swresample-4.dll +0 -0
  1321. mindspore/swscale-6.dll +0 -0
  1322. mindspore/tbbmalloc.dll +0 -0
  1323. mindspore/tinyxml2.dll +0 -0
  1324. mindspore/train/__init__.py +47 -0
  1325. mindspore/train/_utils.py +439 -0
  1326. mindspore/train/amp.py +817 -0
  1327. mindspore/train/anf_ir_pb2.py +1517 -0
  1328. mindspore/train/callback/__init__.py +44 -0
  1329. mindspore/train/callback/_backup_and_restore.py +117 -0
  1330. mindspore/train/callback/_callback.py +613 -0
  1331. mindspore/train/callback/_checkpoint.py +751 -0
  1332. mindspore/train/callback/_cluster_monitor.py +201 -0
  1333. mindspore/train/callback/_dataset_graph.py +150 -0
  1334. mindspore/train/callback/_early_stop.py +239 -0
  1335. mindspore/train/callback/_flops_collector.py +238 -0
  1336. mindspore/train/callback/_history.py +92 -0
  1337. mindspore/train/callback/_lambda_callback.py +80 -0
  1338. mindspore/train/callback/_landscape.py +1049 -0
  1339. mindspore/train/callback/_loss_monitor.py +107 -0
  1340. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1341. mindspore/train/callback/_mindio_ttp.py +443 -0
  1342. mindspore/train/callback/_on_request_exit.py +195 -0
  1343. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1344. mindspore/train/callback/_summary_collector.py +1184 -0
  1345. mindspore/train/callback/_time_monitor.py +141 -0
  1346. mindspore/train/checkpoint_pb2.py +233 -0
  1347. mindspore/train/data_sink.py +219 -0
  1348. mindspore/train/dataset_helper.py +688 -0
  1349. mindspore/train/lineage_pb2.py +1260 -0
  1350. mindspore/train/loss_scale_manager.py +213 -0
  1351. mindspore/train/memory_profiling_pb2.py +298 -0
  1352. mindspore/train/metrics/__init__.py +175 -0
  1353. mindspore/train/metrics/accuracy.py +133 -0
  1354. mindspore/train/metrics/auc.py +129 -0
  1355. mindspore/train/metrics/bleu_score.py +170 -0
  1356. mindspore/train/metrics/confusion_matrix.py +700 -0
  1357. mindspore/train/metrics/cosine_similarity.py +109 -0
  1358. mindspore/train/metrics/dice.py +116 -0
  1359. mindspore/train/metrics/error.py +175 -0
  1360. mindspore/train/metrics/fbeta.py +167 -0
  1361. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1362. mindspore/train/metrics/loss.py +97 -0
  1363. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1364. mindspore/train/metrics/metric.py +373 -0
  1365. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1366. mindspore/train/metrics/perplexity.py +133 -0
  1367. mindspore/train/metrics/precision.py +160 -0
  1368. mindspore/train/metrics/recall.py +159 -0
  1369. mindspore/train/metrics/roc.py +223 -0
  1370. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1371. mindspore/train/metrics/topk.py +167 -0
  1372. mindspore/train/mind_ir_pb2.py +1903 -0
  1373. mindspore/train/model.py +2176 -0
  1374. mindspore/train/node_strategy_pb2.py +653 -0
  1375. mindspore/train/print_pb2.py +184 -0
  1376. mindspore/train/profiling_parallel_pb2.py +151 -0
  1377. mindspore/train/serialization.py +3101 -0
  1378. mindspore/train/summary/__init__.py +23 -0
  1379. mindspore/train/summary/_lineage_adapter.py +41 -0
  1380. mindspore/train/summary/_summary_adapter.py +496 -0
  1381. mindspore/train/summary/_writer_pool.py +207 -0
  1382. mindspore/train/summary/enums.py +56 -0
  1383. mindspore/train/summary/summary_record.py +581 -0
  1384. mindspore/train/summary/writer.py +167 -0
  1385. mindspore/train/summary_pb2.py +1165 -0
  1386. mindspore/train/train_thor/__init__.py +20 -0
  1387. mindspore/train/train_thor/convert_utils.py +268 -0
  1388. mindspore/train/train_thor/dataset_helper.py +192 -0
  1389. mindspore/train/train_thor/model_thor.py +257 -0
  1390. mindspore/turbojpeg.dll +0 -0
  1391. mindspore/vcmeta.dll +0 -0
  1392. mindspore/vcomp140.dll +0 -0
  1393. mindspore/vcruntime140.dll +0 -0
  1394. mindspore/vcruntime140_1.dll +0 -0
  1395. mindspore/version.py +1 -0
  1396. mindspore-2.3.0.dist-info/METADATA +351 -0
  1397. mindspore-2.3.0.dist-info/RECORD +1400 -0
  1398. mindspore-2.3.0.dist-info/WHEEL +5 -0
  1399. mindspore-2.3.0.dist-info/entry_points.txt +4 -0
  1400. mindspore-2.3.0.dist-info/top_level.txt +1 -0
mindspore/nn/cell.py ADDED
@@ -0,0 +1,2605 @@
1
+ # Copyright 2020-2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """cell"""
16
+ from __future__ import absolute_import
17
+
18
+ import gc
19
+ import inspect
20
+ import os
21
+ import time
22
+ from collections import OrderedDict
23
+ import numpy
24
+
25
+ from mindspore._checkparam import args_type_check, check_hook_fn
26
+ from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
27
+ from mindspore import log as logger
28
+ from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
29
+ from mindspore.common.hook_handle import HookHandle
30
+ from mindspore.context import ParallelMode
31
+ from mindspore import context
32
+ from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
33
+ from mindspore import _checkparam as Validator
34
+ from mindspore.common import dtype as mstype
35
+ from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
36
+ from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
37
+ from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
38
+ from mindspore.common.parameter import Parameter, ParameterTuple
39
+ from mindspore.common.tensor import Tensor
40
+ from mindspore.ops.operations import Cast
41
+ from mindspore.ops.primitive import Primitive
42
+ from mindspore.ops.operations import _inner_ops as inner
43
+ from mindspore.parallel.shard import Shard
44
+ from mindspore._check_jit_forbidden_api import jit_forbidden_register
45
+ from mindspore.common._decorator import deprecated
46
+ from mindspore.common._register_for_recompute import recompute_registry
47
+
48
+
49
+ class Cell(Cell_):
50
+ """
51
+ The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
52
+ base class.
53
+
54
+ Layers in `mindspore.nn` are also the subclass of Cell, such as :class:`mindspore.nn.Conv2d`,
55
+ and :class:`mindspore.nn.ReLU`, etc. Cell will be compiled into a calculation
56
+ graph in GRAPH_MODE (static graph mode) and used as the basic module of neural networks in
57
+ PYNATIVE_MODE (dynamic graph mode).
58
+
59
+ .. note::
60
+ A Cell is in inference mode by default. For a class that inherits Cell,
61
+ if the training and inference have different structures, the subclass performs the inference branch by default.
62
+ To set the training mode, refer to `mindspore.nn.Cell.set_train` .
63
+
64
+ .. warning::
65
+ In a subclass of Cell, it is not allowed to define a method named 'cast' or an attribute
66
+ named 'phase' or 'cells'; otherwise, an error will be raised.
67
+
68
+ Args:
69
+ auto_prefix (bool, optional): Whether to automatically generate NameSpace for Cell and its child cells. It also
70
+ affects the names of parameters in the `Cell`. If set to ``True`` , the parameter name will be
71
+ automatically prefixed, otherwise not. In general, the backbone network should be set to
72
+ ``True`` , otherwise the duplicate name problem will appear. The cell to train the backbone
73
+ network, such as optimizer and :class:`mindspore.nn.TrainOneStepCell`, should be set to
74
+ ``False`` , otherwise the parameter name in backbone will be changed by mistake.
75
+ Default: ``True`` .
76
+ flags (dict, optional): Network configuration information, currently it is used for the binding of network
77
+ and dataset. Users can also customize network attributes by this parameter. Default: ``None`` .
78
+
79
+ Supported Platforms:
80
+ ``Ascend`` ``GPU`` ``CPU``
81
+
82
+ Examples:
83
+ >>> import mindspore.nn as nn
84
+ >>> from mindspore import ops
85
+ >>> class MyCell(nn.Cell):
86
+ ... def __init__(self, forward_net):
87
+ ... super(MyCell, self).__init__(auto_prefix=False)
88
+ ... self.net = forward_net
89
+ ... self.relu = ops.ReLU()
90
+ ...
91
+ ... def construct(self, x):
92
+ ... y = self.net(x)
93
+ ... return self.relu(y)
94
+ >>>
95
+ >>> inner_net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
96
+ >>> my_net = MyCell(inner_net)
97
+ >>> print(my_net.trainable_params())
98
+ ... # If 'auto_prefix' is set to True or not set when calling the '__init__' method of the parent class,
99
+ ... # the parameter's name will be 'net.weight'.
100
+ [Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
101
+ """
102
+
103
+ IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
104
+ '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase',
105
+ '_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
106
+ '_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
107
+ '_attr_synced', 'pynative', 'requires_grad', 'cell_type']
108
+ total_instance_count = 0
109
+
110
+ def __init__(self, auto_prefix=True, flags=None):
111
+ Cell_.__init__(self, self._cell_tag)
112
+ Cell.total_instance_count += 1
113
+ self.instance_count = Cell.total_instance_count
114
+ self._params = OrderedDict()
115
+ self._cells = OrderedDict()
116
+ self._params_list = OrderedDict()
117
+ self._primitives = OrderedDict()
118
+ self.training = False
119
+ self.requires_grad = False
120
+ self.pynative = False
121
+ self._attr_synced = False
122
+ self._param_prefix = ''
123
+ self._auto_prefix = auto_prefix
124
+ self._scope = None
125
+ self._phase = 'train'
126
+ self._parameter_layout_dict = {}
127
+ self._parallel_parameter_name_list = ()
128
+ self._parallel_parameter_merge_net_dict = {}
129
+ self._create_time = int(time.time() * 1e9)
130
+ self.arguments_key = ""
131
+ self.compile_cache = set()
132
+ self.phase_cache = dict()
133
+ cells_compile_cache[id(self)] = self.compile_cache
134
+ self.parameter_broadcast_done = False
135
+ self._id = 1
136
+ self.exist_names = set("")
137
+ self.exist_objs = set()
138
+ self.recompute_cell = None
139
+ self.sig = inspect.signature(self.construct)
140
+ init_pipeline()
141
+
142
+ # call gc to release GE session resources used by unused cell objects
143
+ if os.getenv('GC_COLLECT_IN_CELL') == '1':
144
+ gc.collect()
145
+
146
+ if flags:
147
+ self.add_flags(**flags)
148
+ self._bprop_debug = False
149
+ self._forward_pre_hook = OrderedDict()
150
+ self._forward_hook = OrderedDict()
151
+ self._enable_forward_pre_hook = False
152
+ self._enable_forward_hook = False
153
+ self._enable_backward_hook = False
154
+ self._cell_backward_hook = None
155
+ self._is_recursion_hook = False
156
+ self.cell_type = None
157
+ self.cast = Cast()
158
+ self._has_config_recompute = False
159
+ self._user_parameters = []
160
+ self._dynamic_shape_inputs = None
161
+ self._compile_args = None
162
+ self.saved_dynamic_shape = None
163
+ self._jit_config_dict = dict()
164
+ self.grad_ops_label = False
165
+ self.ge_sync_data = False
166
+ self._is_check_and_refresh = False
167
+ self._amp_level = ""
168
+ self._init_flag = False
169
+
170
+ def __getstate__(self):
171
+ base = Cell_.__getstate__(self)
172
+ return base, self.__dict__
173
+
174
+ def __setstate__(self, state):
175
+ base, dict_ = state
176
+ Cell_.__setstate__(self, base)
177
+ self.__dict__ = dict_
178
+ self._attr_synced = False
179
+
180
+ def __bool__(self):
181
+ return True
182
+
183
+ @property
184
+ def _cell_tag(self):
185
+ # `<class 'xxxxxxx'>` to `xxxxxxx`
186
+ return str(self.__class__)[8:-2]
187
+
188
+ @property
189
+ def create_time(self):
190
+ return self._create_time
191
+
192
+ @property
193
+ def cell_init_args(self):
194
+ return self._cell_init_args
195
+
196
+ @property
197
+ def param_prefix(self):
198
+ """
199
+ Param prefix is the prefix of the current cell's direct child parameters.
200
+
201
+ Examples:
202
+ >>> import mindspore as ms
203
+ >>> from mindspore import Tensor, nn
204
+ ...
205
+ >>> class Net(nn.Cell):
206
+ ... def __init__(self):
207
+ ... super(Net, self).__init__()
208
+ ... self.dense = nn.Dense(2, 2)
209
+ ...
210
+ ... def construct(self, x):
211
+ ... x = self.dense(x)
212
+ ... return x
213
+ >>> net = Net()
214
+ >>> net.update_cell_prefix()
215
+ >>> print(net.dense.param_prefix)
216
+ dense
217
+ """
218
+ return self._param_prefix
219
+
220
+ @property
221
+ def bprop_debug(self):
222
+ """
223
+ Get whether cell custom bprop debug is enabled.
224
+
225
+ Tutorial Examples:
226
+ - `Cell and Parameter - Custom Cell Reverse
227
+ <https://mindspore.cn/tutorials/en/master/advanced/modules/layer.html#custom-cell-reverse>`_
228
+ """
229
+ return self._bprop_debug
230
+
231
+ @bprop_debug.setter
232
+ def bprop_debug(self, value):
233
+ """
234
+ Set whether to enable cell custom bprop debug.
235
+
236
+ Note:
237
+ When bprop is defined in a cell, the bprop function will be executed
238
+ in the Python interpreter when bprop debug is true, and will be parsed
239
+ and added to the graph when bprop debug is false.
240
+
241
+ Args:
242
+ value (bool): Specifies whether to enable bprop debug. Default: ``False``.
243
+ """
244
+ if not isinstance(value, bool):
245
+ raise TypeError(f"For 'Cell', the property 'bprop_debug' must be bool type, but got type {type(value)}.")
246
+ self._bprop_debug = value
247
+
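A minimal usage sketch for `bprop_debug` (the `MulAdd` cell below is an illustrative assumption, not taken from this file): when a cell defines a custom `bprop`, setting the flag to True makes that bprop run in the Python interpreter instead of being parsed into the graph, which is convenient for debugging.

>>> import mindspore.nn as nn
>>> class MulAdd(nn.Cell):
...     def construct(self, x, y):
...         return 2 * x + y
...     def bprop(self, x, y, out, dout):
...         # custom gradient of out = 2*x + y: d/dx = 2*dout, d/dy = dout
...         return 2 * dout, dout
>>> net = MulAdd()
>>> net.bprop_debug = True
>>> print(net.bprop_debug)
True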
248
+ def update_cell_prefix(self):
249
+ """
250
+ Update the `param_prefix` of all child cells.
251
+
252
+ After being invoked, the name prefix of each child cell can be obtained through '_param_prefix'.
253
+ """
254
+ cells_name = self.cells_and_names()
255
+
256
+ for cell_name, cell in cells_name:
257
+ cell._param_prefix = cell_name
258
+
259
+ def update_cell_type(self, cell_type):
260
+ """
261
+ The current cell type is updated when a quantization aware training network is encountered.
262
+
263
+ After being invoked, the cell type is set to 'cell_type'.
264
+
265
+ Args:
266
+ cell_type(str): The type of cell to be updated, cell_type can be "quant" or "second-order".
267
+ """
268
+ self.cell_type = cell_type
269
+
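A one-line sketch of `update_cell_type`, assuming an already constructed dense layer; the method simply records the given tag on the cell:

>>> import mindspore.nn as nn
>>> net = nn.Dense(3, 4)
>>> net.update_cell_type("quant")
>>> print(net.cell_type)
quant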
270
+ @cell_init_args.setter
271
+ def cell_init_args(self, value):
272
+ if not isinstance(value, str):
273
+ raise TypeError(f"For 'Cell', the property 'cell_init_args' must be string type, "
274
+ f"but got type {type(value)}.")
275
+ self._cell_init_args = value
276
+
277
+ @property
278
+ def phase(self):
279
+ return self._phase
280
+
281
+ @phase.setter
282
+ def phase(self, value):
283
+ if not isinstance(value, str):
284
+ raise TypeError(f"For 'Cell', the property 'phase' must be string type, but got type {type(value)}.")
285
+ self._phase = value
286
+
287
+ @property
288
+ def parameter_layout_dict(self):
289
+ """
290
+ `parameter_layout_dict` represents the tensor layout of a parameter, which is inferred by shard strategy and
291
+ distributed operator information.
292
+ """
293
+ return self._parameter_layout_dict
294
+
295
+ @property
296
+ def cls_name(self):
297
+ return self.__class__.__name__
298
+
299
+ @parameter_layout_dict.setter
300
+ def parameter_layout_dict(self, value):
301
+ if not isinstance(value, dict):
302
+ raise TypeError(f"For 'Cell', the property 'parameter_layout_dict' must be dict type, "
303
+ f"but got type {type(value)}.")
304
+ self._parameter_layout_dict = value
305
+
306
+ @property
307
+ def parallel_parameter_name_list(self):
308
+ return self._parallel_parameter_name_list
309
+
310
+ @parallel_parameter_name_list.setter
311
+ def parallel_parameter_name_list(self, value):
312
+ if not isinstance(value, list):
313
+ raise TypeError(f"For 'Cell', the property 'parallel_parameter_name_list' must be list type, "
314
+ f"but got type {type(value)}.")
315
+ self._parallel_parameter_name_list = value
316
+
317
+ @property
318
+ def pipeline_stage(self):
319
+ """
320
+ `pipeline_stage` represents the pipeline stage of current Cell.
321
+ """
322
+ return self._pipeline_stage
323
+
324
+ @pipeline_stage.setter
325
+ def pipeline_stage(self, value):
326
+ """
327
+ Set the `pipeline_stage` of a Cell.
328
+
329
+ Args:
330
+ value (int): The pipeline stage of the Cell.
331
+
332
+ Raises:
333
+ TypeError: If `value` is not int type or is a bool type.
334
+ ValueError: If `value` is less than 0.
335
+ """
336
+ if not isinstance(value, int) or isinstance(value, bool):
337
+ raise TypeError("For 'Cell', the property 'pipeline_stage' "
338
+ "must be int type, but got type : {}".format(type(value)))
339
+
340
+ if value < 0:
341
+ raise ValueError("For 'Cell', the property 'pipeline_stage' "
342
+ "can not be less than 0, but got {}".format(value))
343
+ self._pipeline_stage = value
344
+ for item in self.trainable_params():
345
+ item.add_pipeline_stage(value)
346
+
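A small sketch of assigning `pipeline_stage` to sub-cells (the two-layer network and the stage split are illustrative assumptions; actually running pipeline parallelism also requires the corresponding parallel context, which is omitted here):

>>> import mindspore.nn as nn
>>> class TwoLayerNet(nn.Cell):
...     def __init__(self):
...         super(TwoLayerNet, self).__init__()
...         self.layer1 = nn.Dense(8, 8)
...         self.layer2 = nn.Dense(8, 8)
...     def construct(self, x):
...         return self.layer2(self.layer1(x))
>>> net = TwoLayerNet()
>>> net.layer1.pipeline_stage = 0  # parameters of layer1 belong to stage 0
>>> net.layer2.pipeline_stage = 1  # parameters of layer2 belong to stage 1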
347
+ @property
348
+ def pipeline_segment(self):
349
+ return self._pipeline_segment
350
+
351
+ @pipeline_segment.setter
352
+ def pipeline_segment(self, value):
353
+ if not isinstance(value, int) or isinstance(value, bool):
354
+ raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
355
+ "must be int type, but got type : {}".format(type(value)))
356
+
357
+ if value < 0:
358
+ raise ValueError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
359
+ "can not be less than 0, but got {}".format(value))
360
+ self._pipeline_segment = value
361
+
362
+ @property
363
+ def parallel_parameter_merge_net_dict(self):
364
+ return self._parallel_parameter_merge_net_dict
365
+
366
+ @parallel_parameter_merge_net_dict.setter
367
+ def parallel_parameter_merge_net_dict(self, value):
368
+ if not isinstance(value, dict):
369
+ raise TypeError(f"For 'Cell', the property 'parallel_parameter_merge_net_dict' must be dict type, "
370
+ f"but got type {type(value)}.")
371
+ self._parallel_parameter_merge_net_dict = value
372
+
373
+ @property
374
+ def jit_config_dict(self):
375
+ return self._jit_config_dict
376
+
377
+ def get_func_graph_proto(self):
378
+ """Return graph binary proto."""
379
+ exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
380
+ return _cell_graph_executor._get_func_graph_proto(self, exec_id, "anf_ir", True)
381
+
382
+ def __getattr__(self, name):
383
+ if '_params' in self.__dict__:
384
+ params = self.__dict__['_params']
385
+ if name in params:
386
+ return params[name]
387
+ if '_cells' in self.__dict__:
388
+ cells = self.__dict__['_cells']
389
+ if name in cells:
390
+ return cells[name]
391
+ if '_params_list' in self.__dict__:
392
+ params_list = self.__dict__['_params_list']
393
+ if name in params_list:
394
+ return params_list[name]
395
+ raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))
396
+
397
+ def __del__(self):
398
+ if isinstance(cells_compile_cache, dict):
399
+ # when deepcopying a cell instance, the copied cell instance can't be added to cells_compile_cache
400
+ # here using pop(id(self), None) to avoid KeyError exception
401
+ cells_compile_cache.pop(id(self), None)
402
+ if hasattr(self, "compile_cache") and self.compile_cache:
403
+ _cell_graph_executor.del_net_res(self, self.compile_cache)
404
+ if isinstance(self, GraphCell):
405
+ _cell_graph_executor.dec_graph_cell_count()
406
+ Cell.total_instance_count -= 1
407
+
408
+ def __delattr__(self, name):
409
+ if name in self._params:
410
+ del self._params[name]
411
+ elif name in self._cells:
412
+ del self._cells[name]
413
+ elif '_params_list' in self.__dict__ and name in self._params_list:
414
+ del self._params_list[name]
415
+ else:
416
+ object.__delattr__(self, name)
417
+ self._attr_synced = False
418
+
419
+ def _cast_mixed_precision_inputs(self, inputs, dst_type):
420
+ """Cast input for mixed precision"""
421
+ res = list()
422
+ for item in inputs:
423
+ if isinstance(item, tuple):
424
+ res.append(self._cast_mixed_precision_inputs(item, dst_type))
425
+ elif isinstance(item, float):
426
+ res.append(self.cast(item, dst_type))
427
+ elif hasattr(item, "dtype") and item.dtype in \
428
+ {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
429
+ res.append(self.cast(item, dst_type))
430
+ else:
431
+ res.append(item)
432
+ return tuple(res)
433
+
434
+ def cast_inputs(self, inputs, dst_type):
435
+ """
436
+ Cast inputs to specified type.
437
+
438
+ Args:
439
+ inputs (tuple[Tensor]): The cell inputs.
440
+ dst_type (mindspore.dtype): The specified data type.
441
+
442
+ Returns:
443
+ tuple[Tensor], the result with destination data type.
444
+ """
445
+ res = list()
446
+ for item in inputs:
447
+ if isinstance(item, tuple):
448
+ res.append(self.cast_inputs(item, dst_type))
449
+ else:
450
+ res.append(self.cast(item, dst_type))
451
+ return tuple(res)
452
+
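A minimal sketch of `cast_inputs`: it walks the (possibly nested) input tuple and casts every element to the destination dtype, so a float32 tensor comes back as float16 below. The concrete cell (`nn.ReLU`) is only a stand-in.

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> net = nn.ReLU()
>>> x = Tensor(np.ones((2, 2), np.float32))
>>> (y,) = net.cast_inputs((x,), ms.float16)
>>> print(y.dtype)
Float16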
453
+ def _do_parameter_broadcast(self):
454
+ if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
455
+ if not self.parameter_broadcast_done:
456
+ _pynative_executor.parameter_broadcast(self, self.phase)
457
+ self.parameter_broadcast_done = True
458
+
459
+ def run_construct(self, cast_inputs, kwargs):
460
+ """
461
+ Run the construct function.
462
+
463
+ Note:
464
+ This function will be removed in a future version. It is not recommended to call this function.
465
+
466
+ Args:
467
+ cast_inputs (tuple): The input objects of Cell.
468
+ kwargs (dict): Provide keyword arguments.
469
+
470
+ Returns:
471
+ output, the output object of Cell.
472
+ """
473
+ logger.warning(f"The 'run_construct' function of '{self.cls_name}' will be removed in a future version. "
474
+ f"Calling this function is not recommended.")
475
+ output = self._run_construct(cast_inputs, kwargs)
476
+ return output
477
+
478
+ def _run_construct(self, cast_inputs, kwargs):
479
+ """Run the construct function"""
480
+ if self._enable_forward_pre_hook:
481
+ cast_inputs = self._run_forward_pre_hook(cast_inputs)
482
+ if self._enable_backward_hook:
483
+ output = self._backward_hook_construct(*cast_inputs, **kwargs)
484
+ elif hasattr(self, "_shard_fn"):
485
+ output = self._shard_fn(*cast_inputs, **kwargs)
486
+ else:
487
+ if self.recompute_cell is not None:
488
+ output = self.recompute_cell(*cast_inputs, **kwargs)
489
+ else:
490
+ output = self.construct(*cast_inputs, **kwargs)
491
+ if self._enable_forward_hook:
492
+ output = self._run_forward_hook(cast_inputs, output)
493
+ return output
494
+
495
+ def _check_construct_args(self, *args):
496
+ """Check the args needed by the function construct"""
497
+ positional_args = 0
498
+ default_args = 0
499
+ has_var = False
500
+ for value in inspect.signature(self.construct).parameters.values():
501
+ if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
502
+ has_var = True
503
+ if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
504
+ if value.default is inspect.Parameter.empty:
505
+ positional_args += 1
506
+ else:
507
+ default_args += 1
508
+
509
+ if has_var:
510
+ return
511
+
512
+ if len(args) < positional_args:
513
+ raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument, "
514
+ f"but got {len(args)}. When using set_inputs, please make sure that all networks "
515
+ f"and loss functions are configured with set_inputs.")
516
+
517
+ if len(args) > positional_args + default_args:
518
+ construct_inputs_names = self.construct.__code__.co_varnames
519
+ if 'self' not in construct_inputs_names:
520
+ raise TypeError(f"For 'Cell', the method 'construct' must have parameter 'self'. ")
521
+
522
+ raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument and "
523
+ f"{default_args} default argument, total {positional_args + default_args}, "
524
+ f"but got {len(args)}.")
525
+
526
+ def _hook_fn_registered(self):
527
+ '''Hook function in graph mode'''
528
+ # Check super().__init__() in graph mode.
529
+ try:
530
+ if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
531
+ return True
532
+ except AttributeError as e:
533
+ raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
534
+ f"Please use 'super().__init__()'.") from e
535
+ if not self._is_recursion_hook:
536
+ self._is_recursion_hook = True
537
+ for cell in self.cells():
538
+ if cell._hook_fn_registered():
539
+ return True
540
+ return False
541
+
542
+ def _get_prims_recursively(self):
543
+ all_prims = list()
544
+ for _, value in self._primitives.items():
545
+ if value:
546
+ all_prims.append(value)
547
+
548
+ for cell in self.cells():
549
+ all_prims.extend(cell._get_prims_recursively())
550
+
551
+ return all_prims
552
+
553
+ def set_data_parallel(self):
554
+ """
555
+ For all primitive ops in this cell(including ops of cells that wrapped by this cell),
556
+ if parallel strategy is not specified, then instead of auto-searching, data parallel
557
+ strategy will be generated for those primitive ops.
558
+
559
+ Note:
560
+ Only effective when the parallel mode is ParallelMode.AUTO_PARALLEL under graph mode.
561
+
562
+ Examples:
563
+ >>> import mindspore.nn as nn
564
+ >>> net = nn.Dense(3, 4)
565
+ >>> net.set_data_parallel()
566
+ """
567
+ if context._get_mode() == context.PYNATIVE_MODE:
568
+ raise ValueError("set_data_parallel: does not support PyNative mode.")
569
+
570
+ all_prims = self._get_prims_recursively()
571
+ for prim in all_prims:
572
+ prim.add_prim_attr("strategy_gen_mode", "data_parallel")
573
+
574
+ def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
575
+ """
576
+ Define the input and output layouts of this cell; the parallel strategies of the remaining ops will be
577
+ generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed
578
+ execution in graph mode. In Graph mode, use this method to specify the distribution strategy for a Cell;
582
+ strategies for other cells will be set by sharding propagation.
580
+ in_strategy and out_strategy define the input and output layout respectively.
581
+ in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
582
+ this input/output, and None represents data_parallel,
583
+ which can refer to the description of `mindspore.ops.Primitive.shard`.
584
+ The parallel strategies of remaining operators are derived from the strategy specified by the input and output.
585
+
586
+ Note:
587
+ If Cell.shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to
588
+ "auto_parallel" and the search mode (search_mode) to "sharding_propagation".
589
+ If the inputs contain a Parameter, its strategy should be set in `in_strategy`.
590
+
591
+ Args:
592
+ in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
593
+ defines the layout of the corresponding input and None represents a data parallel strategy.
594
+ out_strategy (Union[None, tuple]): Define the layout of outputs similar to `in_strategy`.
595
+ It is not in use right now. Default: ``None`` .
596
+ parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
597
+ defines the layout of the parameter like "param_name: layout".
598
+ The key is a parameter name of type 'str'.
599
+ The value is a 1-D integer tuple, indicating the corresponding layout.
600
+ If the parameter name is incorrect or the corresponding parameter
601
+ has been set, the parameter setting will be ignored.
602
+ Default: ``None`` .
603
+ device (string): Select a certain device target. It is not in use right now.
604
+ Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
605
+ level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
606
+ over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
607
+ use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``0`` .
608
+
609
+ Returns:
610
+ Function, return the cell construct function that will be executed under auto parallel process.
611
+
612
+ Examples:
613
+ >>> import mindspore.nn as nn
614
+ >>>
615
+ >>> class Block(nn.Cell):
616
+ ... def __init__(self):
+ ... super(Block, self).__init__()
617
+ ... self.dense1 = nn.Dense(10, 10)
618
+ ... self.relu = nn.ReLU()
619
+ ... self.dense2 = nn.Dense(10, 10)
620
+ ... def construct(self, x):
621
+ ... x = self.relu(self.dense2(self.relu(self.dense1(x))))
622
+ ... return x
623
+ >>>
624
+ >>> class example(nn.Cell):
625
+ ... def __init__(self):
+ ... super(example, self).__init__()
626
+ ... self.block1 = Block()
627
+ ... self.block2 = Block()
628
+ ... self.block2_shard = self.block2.shard(in_strategy=((2, 1),), out_strategy=(None,),
629
+ ... parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
630
+ ... def construct(self, x):
631
+ ... x = self.block1(x)
632
+ ... x = self.block2_shard(x)
633
+ ... return x
634
+ """
635
+ if context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel", "semi_auto_parallel"]:
636
+ raise AssertionError(f"Cell shard only supports auto parallel or semi_auto_parallel "
637
+ f"Please check the parallel mode in parallel context.")
638
+
639
+ shard_fn = Shard()
640
+ fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
641
+ object.__setattr__(self, "_shard_fn", fn)
642
+ return fn
643
+
644
+ def auto_cast_inputs(self, inputs):
645
+ """
646
+ Auto cast inputs in mixed precision scenarios.
647
+
648
+ Args:
649
+ inputs (tuple): the inputs of construct.
650
+
651
+ Returns:
652
+ Tuple, the inputs after data type cast.
653
+ """
654
+ msg = f"'auto_cast_inputs' is deprecated from version 2.0 and will be removed in a future version."
655
+ logger.warning(msg)
656
+ cast_inputs = inputs
657
+ mixed_type = self.get_mixed_precision_type()
658
+ if mixed_type == MixedPrecisionType.FP16:
659
+ cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16)
660
+ if mixed_type == MixedPrecisionType.FP32:
661
+ cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
662
+
663
+ return cast_inputs
664
+
665
+ def _init_check(self):
666
+ for param in self.get_parameters(expand=False):
667
+ if param.has_init:
668
+ param.init_data()
669
+
670
+ def _self_check(self):
671
+ if not self._is_check_and_refresh:
672
+ self.check_names_and_refresh_name()
673
+ self._is_check_and_refresh = True
674
+
675
+ def _predict(self, *args, **kwargs):
676
+ if not hasattr(self, "phase"):
677
+ return False, None
678
+ if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache:
679
+ new_args = _get_args_for_run_predict(self, args, kwargs, self._compile_args)
680
+ res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase])
681
+ res = _convert_python_data(res)
682
+ return True, res
683
+ return False, None
684
+
685
+ def __call__(self, *args, **kwargs):
686
+ # Run in Graph mode.
687
+ if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
688
+ if kwargs:
689
+ bound_arguments = self.sig.bind(*args, **kwargs)
690
+ bound_arguments.apply_defaults()
691
+ args = bound_arguments.args
692
+ kwargs = bound_arguments.kwargs
693
+
694
+ predict_compiled, res = self._predict(*args, **kwargs)
695
+ if predict_compiled:
696
+ return res
697
+ self._check_construct_args(*args)
698
+
699
+ if self._hook_fn_registered():
700
+ logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
701
+ f"function, please use context.set_context to set pynative mode.")
702
+ self._self_check()
703
+ out = self.compile_and_run(*args, **kwargs)
704
+ return out
705
+
706
+ # Run in PyNative mode.
707
+ self._self_check()
708
+ if not self._init_flag:
709
+ self._init_check()
710
+ self._init_flag = True
711
+
712
+ if self.requires_grad:
713
+ _pynative_executor.set_grad_flag(True)
714
+
715
+ try:
716
+ _pynative_executor.new_graph(self, *args, **kwargs)
717
+ output = self._run_construct(args, kwargs)
718
+ _pynative_executor.end_graph(self, output, *args, **kwargs)
719
+ except Exception as err:
720
+ _pynative_executor.clear_res()
721
+ raise err
722
+
723
+ return output
724
+
725
+ def _add_attr(self, name, value):
726
+ if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
727
+ super(Cell, self)._add_attr(name, value)
728
+
729
+ def _sync_attr_for_compile(self):
730
+ """Sync the attr to c++ object."""
731
+ if self._attr_synced:
732
+ return
733
+ cells = self.__dict__.get('_cells')
734
+ for key in cells:
735
+ cell = cells[key]
736
+ cell._sync_attr_for_compile()
737
+ self._add_attr(key, cell)
738
+ params = self.__dict__.get('_params')
739
+ for key in params:
740
+ if '.' in key:
741
+ continue
742
+ param = params[key]
743
+ self._add_attr(key, param)
744
+ params_list = self.__dict__.get('_params_list')
745
+ for key in params_list:
746
+ params_list_item = params_list[key]
747
+ self._add_attr(key, params_list_item)
748
+ for key in self.__dict__:
749
+ value = self.__dict__[key]
750
+ self._add_attr(key, value)
751
+ self._attr_synced = True
752
+
753
+ def _set_attr_for_parameter(self, name, value):
754
+ """Set attr for parameter."""
755
+ cells = self.__dict__.get('_cells')
756
+ params = self.__dict__.get('_params')
757
+ if params is None:
758
+ raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.")
759
+ if name in self.__dict__:
760
+ if self.__dict__[name] is not None:
761
+ raise TypeError(f"For 'Cell', the {name} should not be Parameter.")
762
+ del self.__dict__[name]
763
+ if cells and name in cells:
764
+ raise TypeError(f"For 'Cell', the {name} must be Cell, but got Parameter.")
765
+ self.insert_param_to_cell(name, value)
766
+
767
+ def _set_attr_for_parameter_tuple(self, name, value):
768
+ """Set attr for parameter in ParameterTuple."""
769
+ params = self.__dict__.get('_params')
770
+ params_list = self.__dict__.get('_params_list')
771
+ if params is None:
772
+ raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.")
773
+ exist_names = set("")
774
+ exist_objs = set()
775
+ for item in value:
776
+ if item in exist_objs:
777
+ # If there are multiple identical objects, their names are only checked once.
778
+ continue
779
+ exist_objs.add(item)
780
+ if item.name == PARAMETER_NAME_DEFAULT:
781
+ logger.warning("For 'Cell', the parameter definition is deprecated.\n"
782
+ "Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
783
+ item.name = item.name + "$" + str(self._id)
784
+ self._id += 1
785
+ self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
786
+ if item.name in exist_names:
787
+ raise ValueError("The value {} , its name '{}' already exists. "
788
+ "Please set a unique name for the parameter.".format(value, item.name))
789
+ exist_names.add(item.name)
790
+
791
+ if context._get_mode() == context.PYNATIVE_MODE:
792
+ if name in self.__dict__:
793
+ del self.__dict__[name]
794
+ if name in params:
795
+ del params[name]
796
+ params_list[name] = value
797
+ else:
798
+ object.__setattr__(self, name, value)
799
+
800
+ def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
801
+ """Set attr for parameter in list or tuple."""
802
+ for item in value:
803
+ if item in self.exist_objs:
804
+ # If there are multiple identical objects, their names are only checked once.
805
+ continue
806
+ self.exist_objs.add(item)
807
+ if item.name == PARAMETER_NAME_DEFAULT:
808
+ item.name = item.name + "$" + str(self._id)
809
+ self._id += 1
810
+ if item.name in self.exist_names:
811
+ raise ValueError("The value {} , its name '{}' already exists. "
812
+ "Please set a unique name for the parameter.".format(value, item.name))
813
+ self.exist_names.add(item.name)
814
+ object.__setattr__(self, name, value)
815
+
816
+ def _set_attr_for_cell(self, name, value):
817
+ """Set attr for cell."""
818
+ cells = self.__dict__.get('_cells')
819
+ params = self.__dict__.get('_params')
820
+ if cells is None:
821
+ raise AttributeError("For 'Cell', can not assign cells before Cell.__init__() is called.")
822
+ if name in self.__dict__:
823
+ del self.__dict__[name]
824
+ if params and name in params:
825
+ raise TypeError(f"For 'Cell', the {name} must be Parameter, but got Cell.")
826
+ if self._auto_prefix:
827
+ value.update_parameters_name(name + '.')
828
+ cells[name] = value
829
+ if hasattr(self, '_cell_init_args'):
830
+ self.cell_init_args += str({name: value})
831
+
832
+ def _set_attr_for_params(self, name, value):
833
+ if isinstance(value, Tensor) and self._params[name] is not None:
834
+ self._params[name].set_data(value)
835
+ elif value is not None:
836
+ raise TypeError(f"For 'Cell', the type of {name} must be Parameter or ParameterTuple, "
837
+ f"but got {type(value).__name__}.")
838
+ else:
839
+ self.insert_param_to_cell(name, None)
840
+
841
+ def __setattr__(self, name, value):
842
+ cells = self.__dict__.get('_cells')
843
+ params = self.__dict__.get('_params')
844
+ if isinstance(value, Parameter):
845
+ self._set_attr_for_parameter(name, value)
846
+ elif isinstance(value, ParameterTuple):
847
+ self._set_attr_for_parameter_tuple(name, value)
848
+ elif isinstance(value, (list, tuple)) and value and _check_param_list_tuple(value):
849
+ self._set_attr_for_parameter_in_list_or_tuple(name, value)
850
+ elif isinstance(value, Cell):
851
+ self._set_attr_for_cell(name, value)
852
+ elif params and name in params:
853
+ self._set_attr_for_params(name, value)
854
+ elif cells and name in cells:
855
+ if value is not None:
856
+ raise TypeError(f"For 'Cell', the type of {name} must be cell, but got {type(value).__name__}.")
857
+ self._cells[name] = None
858
+ else:
859
+ if isinstance(value, Primitive):
860
+ value.set_prim_instance_name(name)
861
+ self._primitives[name] = value
862
+ object.__setattr__(self, name, value)
863
+ if name not in Cell.IGNORE_LIST:
864
+ self._attr_synced = False
865
+
866
+ def extend_repr(self):
867
+ """
868
+ Expand the description of Cell.
869
+
870
+ To print customized extended information, re-implement this method in your own cells.
871
+ """
872
+ return ''
873
+
874
+ def __str__(self):
875
+ return self.__repr__()
876
+
877
+ def __repr__(self):
878
+ extra_str = self.extend_repr()
879
+ info_str = self.__class__.__name__ + '<'
880
+ if self._cells:
881
+ sub_str = '\n'
882
+ if extra_str:
883
+ sub_str += '{}\n'.format(self.extend_repr())
884
+ for key, value in self._cells.items():
885
+ sub_str += '({}): {}\n'.format(key, repr(value))
886
+ sub_str = sub_str.replace('\n', '\n ') + '>'
887
+ info_str += sub_str
888
+ else:
889
+ info_str += extra_str + '>'
890
+ return info_str
891
+
892
+ def load_parameter_slice(self, params):
893
+ """
894
+ Replace parameters with sliced tensors by parallel strategies.
895
+
896
+ Note:
897
+ This interface is deprecated.
898
+ """
899
+ logger.warning("'load_parameter_slice' function is deprecated.")
900
+
901
+ def set_parallel_input_with_inputs(self, *inputs):
902
+ """
903
+ Slice inputs tensors by parallel strategies.
904
+
905
+ Note:
906
+ This interface is deprecated.
907
+ """
908
+ logger.warning("'set_parallel_input_with_inputs' function is deprecated.")
909
+
910
+ def set_inputs(self, *inputs, **kwargs):
911
+ """
912
+ Save the compile inputs for the computation graph. The number of inputs should be the same as that of the datasets. When
913
+ using Model for dynamic shape, please make sure that all networks and loss functions passed to the Model are
914
+ configured with set_inputs. The shape of the input Tensor can be either dynamic or static.
915
+
916
+ .. note::
917
+ There are two modes:
918
+
919
+ - Full mode: arguments will be used as all compile inputs for graph-compiling.
920
+ - Incremental mode: arguments will be set to some of the Cell inputs, and will be substituted into the inputs
921
+ at the corresponding positions for graph-compiling.
922
+
923
+ Only one of inputs or kwargs can be set: inputs for full mode and kwargs for incremental mode.
924
+
925
+ Args:
926
+ inputs (tuple): Full mode arguments.
927
+ kwargs (dict): Incremental mode arguments. The acceptable key is the name of parameter defined
928
+ in `self.construct`.
929
+
930
+ .. warning::
931
+ This is an experimental API that is subject to change or deletion.
932
+
933
+ Examples:
934
+ >>> import numpy as np
935
+ >>> import mindspore as ms
936
+ >>> from mindspore import nn, Tensor
937
+ >>>
938
+ >>> class ReluNet(nn.Cell):
939
+ ... def __init__(self):
940
+ ... super(ReluNet, self).__init__()
941
+ ... self.relu = nn.ReLU()
942
+ ... def construct(self, x):
943
+ ... return self.relu(x)
944
+ >>>
945
+ >>> net = ReluNet()
946
+ >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
947
+ >>> net.set_inputs(input_dyn)
948
+ >>> input = Tensor(np.random.random([3, 10]), dtype=ms.float32)
949
+ >>> output = net(input)
950
+ >>>
951
+ >>> net2 = ReluNet()
952
+ >>> net2.set_inputs(x=input_dyn)
953
+ >>> output = net2(input)
954
+ """
955
+ if self.grad_ops_label:
956
+ logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is '
957
+ f'generated.')
958
+ if kwargs and inputs:
959
+ raise ValueError('For Cell, set_inputs should only set inputs or kwargs(inputs: %s, kwargs: %s)!'
960
+ % (inputs, kwargs))
961
+
962
+ if not kwargs:
963
+ self._dynamic_shape_inputs = inputs
964
+ self._check_construct_args(*inputs)
965
+ if context._get_mode() == context.PYNATIVE_MODE:
966
+ _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
967
+ else:
968
+ self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)
969
+
970
+ def get_inputs(self):
971
+ """
972
+ Returns the dynamic_inputs of a cell object in one network.
973
+
974
+ Returns:
975
+ inputs (tuple), Inputs of the Cell object.
976
+
977
+ .. warning::
978
+ This is an experimental API that is subject to change or deletion.
979
+
980
+ Examples:
981
+ >>> import numpy as np
982
+ >>> import mindspore as ms
983
+ >>> from mindspore import nn, Tensor
984
+ >>>
985
+ >>> class ReluNet(nn.Cell):
986
+ ... def __init__(self):
987
+ ... super(ReluNet, self).__init__()
988
+ ... self.relu = nn.ReLU()
989
+ ... def construct(self, x):
990
+ ... return self.relu(x)
991
+ >>>
992
+ >>> net = ReluNet()
993
+ >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
994
+ >>> net.set_inputs(input_dyn)
995
+ >>> get_inputs = net.get_inputs()
996
+ >>> print(get_inputs)
997
+ (Tensor(shape=[3, -1], dtype=Float32, value= ),)
998
+
999
+ """
1000
+
1001
+ return self._dynamic_shape_inputs
1002
+
1003
+ def _check_parameter_consistency(self, set_inputs, net_inputs):
1004
+ """Check consistency for parameter."""
1005
+ for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
1006
+ if isinstance(set_input, Tensor):
1007
+ if not isinstance(net_input, Tensor):
1008
+ raise TypeError(
1009
+ f"For 'set_inputs' and tuple(list) in 'set_inputs',the type of {index + 1}th input must "
1010
+ f"be Tensor, but got {type(net_input)}.")
1011
+ if isinstance(set_input, Parameter) != isinstance(net_input, Parameter):
1012
+ raise TypeError(
1013
+ f"For 'set_inputs' and tuple(list) in 'set_inputs', the {index + 1}th input must be the same "
1014
+ f"as expected, but got expected: {type(set_input)} and input: {type(net_input)}.")
1015
+ elif isinstance(set_input, (tuple, list)):
1016
+ if not isinstance(net_input, (tuple, list)):
1017
+ raise TypeError(
1018
+ f"The {index + 1}th input type of 'set_inputs' or tuple(list) in "
1019
+ f"'set_inputs' must be tuple or list, but got {type(net_input)}.")
1020
+ self._check_parameter_consistency(set_input, net_input)
1021
+
1022
+ def _get_compile_args(self, args):
1023
+ """Get compile arguments."""
1024
+ # this is used only for test
1025
+ set_by_auto_dynamic = False
1026
+ if is_auto_dynamic():
1027
+ if self._dynamic_shape_inputs is None:
1028
+ set_by_auto_dynamic = True
1029
+ else:
1030
+ if isinstance(self._dynamic_shape_inputs, (list, tuple)) and self._dynamic_shape_inputs[0] is None:
1031
+ set_by_auto_dynamic = True
1032
+ if set_by_auto_dynamic:
1033
+ self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)
1034
+
1035
+ if self._dynamic_shape_inputs is not None:
1036
+ logger.debug("Compiled Graph with dynamic shape")
1037
+ compile_args = _generate_dyn_compile_args(args, self._dynamic_shape_inputs)
1038
+ _cell_graph_executor._graph_executor.check_argument_consistency(compile_args, args, "set_inputs")
1039
+ self._check_parameter_consistency(compile_args, args)
1040
+ Validator.check_symbolic_shape(compile_args, args)
1041
+ self.saved_dynamic_shape = compile_args
1042
+ return compile_args
1043
+ return args
1044
+
1045
+ def compile(self, *args, **kwargs):
1046
+ """
1047
+ Compile Cell as a computation graph; the input must be consistent with the input defined in construct.
1048
+
1049
+ Args:
1050
+ args (tuple): Args of the Cell object.
1051
+ kwargs (dict): Kwargs of the Cell object.
1052
+ """
1053
+ self._compile_args = self._get_compile_args(args)
1054
+ _cell_graph_executor.compile(self, *self._compile_args, phase=self.phase,
1055
+ jit_config_dict=self._jit_config_dict, **kwargs)
1056
+
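+ # Illustrative usage sketch (added note, not part of the original source): in
+ # graph mode, `compile` can be called once ahead of time so that the first
+ # real call reuses the already-built graph. The Dense cell below is only an
+ # example choice.
+ # >>> import numpy as np
+ # >>> import mindspore as ms
+ # >>> from mindspore import nn, Tensor
+ # >>> ms.set_context(mode=ms.GRAPH_MODE)
+ # >>> net = nn.Dense(2, 2)
+ # >>> x = Tensor(np.ones([1, 2]), ms.float32)
+ # >>> net.compile(x)       # build the graph for this input signature
+ # >>> out = net(x)         # run with the compiled graph
+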
1057
+ def compile_and_run(self, *args, **kwargs):
1058
+ """
1059
+ Compile and run Cell; the input must be consistent with the input defined in construct.
1060
+
1061
+ Note:
1062
+ It is not recommended to call this function directly.
1063
+
1064
+ Args:
1065
+ args (tuple): Args of the Cell object.
1066
+ kwargs (dict): Kwargs of the Cell object.
1067
+
1068
+ Returns:
1069
+ Object, the result of execution.
1070
+ """
1071
+ self.compile(*args, **kwargs)
1072
+ self.add_flags(ge_sync_data=False)
1073
+ new_args = _get_args_for_run(self, args, kwargs, self._compile_args)
1074
+ return _cell_graph_executor(self, *new_args, phase=self.phase)
1075
+
1076
+ def auto_parallel_compile_and_run(self):
1077
+ """
1078
+ Whether or not to execute compile and run in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
1079
+
1080
+ Note:
1081
+ This interface is deprecated.
1082
+ """
1083
+ logger.warning("'auto_parallel_compile_and_run' function is deprecated.")
1084
+
1085
+ def exec_checkpoint_graph(self):
1086
+ """Executes GE saving checkpoint graph operation."""
1087
+ logger.warning("'exec_checkpoint_graph' function is deprecated.")
1088
+ self.add_flags(ge_sync_data=True)
1089
+ _cell_graph_executor(self, phase='save')
1090
+
1091
+ def insert_param_to_cell(self, param_name, param, check_name_contain_dot=True):
1092
+ """
1093
+ Adds a parameter to the current cell.
1094
+
1095
+ Inserts a parameter with given name to the cell. The method is currently used in
1096
+ `mindspore.nn.Cell.__setattr__`.
1097
+
1098
+ Args:
1099
+ param_name (str): Name of the parameter.
1100
+ param (Parameter): Parameter to be inserted to the cell.
1101
+ check_name_contain_dot (bool): Determines whether to check that `param_name` does not contain '.'. Default: ``True`` .
1102
+
1103
+ Raises:
1104
+ KeyError: If the name of parameter is null or contains dot.
1105
+ TypeError: If the type of parameter is not Parameter.
1106
+
1107
+ Examples:
1108
+ >>> import mindspore as ms
1109
+ >>> from mindspore import Tensor, nn, Parameter
1110
+ ...
1111
+ >>> class Net(nn.Cell):
1112
+ ... def __init__(self):
1113
+ ... super(Net, self).__init__()
1114
+ ... self.relu = nn.ReLU()
1115
+ ...
1116
+ ... def construct(self, x):
1117
+ ... x = self.relu(x)
1118
+ ... return x
1119
+ >>> net = Net()
1120
+ >>> net.insert_param_to_cell("bias", Parameter(Tensor([1, 2, 3])))
1121
+ >>> print(net.bias)
1122
+ Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
1123
+ """
1124
+ if not param_name:
1125
+ raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not be None.")
1126
+ if check_name_contain_dot and '.' in param_name:
1127
+ raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not contain'.' ")
1128
+ if '_params' not in self.__dict__:
1129
+ raise AttributeError(f"For 'insert_param_to_cell', please call Cell.__init__() firstly.")
1130
+ if hasattr(self, param_name) and param_name not in self._params:
1131
+ raise KeyError(f"For 'insert_param_to_cell', the {param_name} parameter already exists in the network."
1132
+ f"Cannot insert another parameter with the same name.")
1133
+ if not isinstance(param, Parameter) and param is not None:
1134
+ raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
1135
+ f"but got {type(param)}.")
1136
+ if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
1137
+ param.name = param_name
1138
+ self._params[param_name] = param
1139
+
1140
+ def cast_param(self, param):
1141
+ """
1142
+ Cast parameter according to auto mix precision level in pynative mode.
1143
+
1144
+ This interface is currently used in the case of auto mixed precision and usually does not need to be called explicitly.
1145
+
1146
+ Args:
1147
+ param (Parameter): Parameters, the type of which should be cast.
1148
+
1149
+ Returns:
1150
+ Parameter, the input parameter with type automatically cast.
1151
+ """
1152
+ msg = f"'cast_param' is deprecated from version 2.0 and will be removed in a future version."
1153
+ logger.warning(msg)
1154
+ mixed_type = self.get_mixed_precision_type()
1155
+ if mixed_type != MixedPrecisionType.NOTSET:
1156
+ if mixed_type == MixedPrecisionType.FP32:
1157
+ param.set_cast_dtype(mstype.float32)
1158
+ elif mixed_type == MixedPrecisionType.FP16:
1159
+ param.set_cast_dtype(mstype.float16)
1160
+ elif hasattr(param, "set_cast_dtype"):
1161
+ # reset the cast dtype
1162
+ param.set_cast_dtype()
1163
+ return param
1164
+
1165
+ def insert_child_to_cell(self, child_name, child_cell):
1166
+ """
1167
+ Adds a child cell to the current cell with a given name.
1168
+
1169
+ Args:
1170
+ child_name (str): Name of the child cell.
1171
+ child_cell (Cell): The child cell to be inserted.
1172
+
1173
+ Raises:
1174
+ KeyError: Child Cell's name is incorrect or duplicates another child cell's name.
1175
+ TypeError: If type of `child_name` is not str.
1176
+ TypeError: Child Cell's type is incorrect.
1177
+
1178
+ Examples:
1179
+ >>> import mindspore as ms
1180
+ >>> from mindspore import Tensor, nn
1181
+ ...
1182
+ >>> net1 = nn.ReLU()
1183
+ >>> net2 = nn.Dense(2, 2)
1184
+ >>> net1.insert_child_to_cell("child", net2)
1185
+ >>> print(net1)
1186
+ ReLU<
1187
+ (child): Dense<input_channels=2, output_channels=2, has_bias=True>
1188
+ >
1189
+ """
1190
+ if not isinstance(child_name, str):
1191
+ raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
1192
+ f"but got {type(child_name)}.")
1193
+ if not child_name or '.' in child_name:
1194
+ raise KeyError(f"For 'insert_child_to_cell', the parameter 'child_name' can not be None and "
1195
+ "can not contain '.' ")
1196
+ if hasattr(self, child_name) and child_name not in self._cells:
1197
+ raise KeyError(f"For 'insert_child_to_cell', the {child_name} child cell already exists in the network."
1198
+ f"Cannot insert another child cell with the same name.")
1199
+ if not isinstance(child_cell, Cell) and child_cell is not None:
1200
+ raise TypeError(f"For 'insert_child_to_cell', the argument 'child_cell' must be 'Cell' if not None, "
1201
+ f"but got type {type(child_cell)}.")
1202
+ self._cells[child_name] = child_cell
1203
+
1204
+ def construct(self, *args, **kwargs):
1205
+ """
1206
+ Defines the computation to be performed. This method must be overridden by all subclasses.
1207
+
1208
+ Note:
1209
+ It is currently not supported for the inputs to contain both tuple and non-tuple types at the same time.
1210
+
1211
+ Args:
1212
+ args (tuple): Tuple of variable parameters.
1213
+ kwargs (dict): Dictionary of variable keyword parameters.
1214
+
1215
+ Returns:
1216
+ Tensor, returns the computed result.
1217
+ """
1218
+ raise AttributeError("For 'Cell', the method 'construct' is not defined.")
1219
+
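+ # Minimal sketch of overriding construct (added note, not from the original
+ # source): a subclass defines its forward computation in construct and the
+ # instance is then called like a function.
+ # >>> from mindspore import nn
+ # >>> class Square(nn.Cell):
+ # ...     def construct(self, x):
+ # ...         return x * x
+ # >>> net = Square()
+ # >>> # net(x) dispatches to Square.construct(x)
+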
1220
+ def remove_redundant_parameters(self):
1221
+ """
1222
+ Remove the redundant parameters.
1223
+
1224
+ This interface usually does not need to be called explicitly.
1225
+ """
1226
+ cells = self.cells_and_names()
1227
+ for _, cell in cells:
1228
+ params = cell._params.items()
1229
+ for param_name, param in list(params):
1230
+ if param.name not in self.parallel_parameter_name_list:
1231
+ cell._params.pop(param_name)
1232
+ logger.info("remove the redundant parameter: %s", param.name)
1233
+ continue
1234
+ cell_dict = cell.__dict__
1235
+ for key in cell_dict:
1236
+ if isinstance(cell_dict[key], ParameterTuple):
1237
+ param_tuple = cell_dict[key]
1238
+ new_param_tuple = []
1239
+ for param in param_tuple:
1240
+ if param.name not in self.parallel_parameter_name_list:
1241
+ logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
1242
+ continue
1243
+ new_param_tuple.append(param)
1244
+ cell.__dict__[key] = ParameterTuple(new_param_tuple)
1245
+
1246
+ def init_parameters_data(self, auto_parallel_mode=False):
1247
+ """
1248
+ Initialize all parameters and replace the original saved parameters in cell.
1249
+
1250
+ Note:
1251
+ trainable_params() and other similar interfaces may return different parameter instances after
1253
+ `init_parameters_data` is called, so do not cache these results.
1253
+
1254
+ Args:
1255
+ auto_parallel_mode (bool): If running in auto_parallel_mode. Default: ``False`` .
1256
+
1257
+ Returns:
1258
+ Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
1259
+
1260
+ Examples:
1261
+ >>> import mindspore as ms
1262
+ >>> from mindspore import Tensor, nn
1263
+ ...
1264
+ >>> class Net(nn.Cell):
1265
+ ... def __init__(self):
1266
+ ... super(Net, self).__init__()
1267
+ ... self.dense = nn.Dense(2, 2)
1268
+ ...
1269
+ ... def construct(self, x):
1270
+ ... x = self.dense(x)
1271
+ ... return x
1272
+ >>> net = Net()
1273
+ >>> print(net.init_parameters_data())
1274
+ {Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True):
1275
+ Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True),
1276
+ Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True):
1277
+ Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True)}
1278
+ """
1279
+ replace = dict()
1280
+
1281
+ def _updata(param):
1282
+ if param in replace:
1283
+ return replace.get(param)
1284
+ new_p = param.init_data(None, set_sliced=False)
1285
+ replace[param] = new_p
1286
+ return new_p
1287
+
1288
+ # replace all original usage.
1289
+ cells = self.cells_and_names()
1290
+ for _, cell in cells:
1291
+ params = cell._params.items()
1292
+ for param_name, param in params:
1293
+ if not auto_parallel_mode:
1294
+ cell._params[param_name] = _updata(param)
1295
+ continue
1296
+ if param.name in self.parallel_parameter_name_list:
1297
+ cell._params[param_name] = _updata(param)
1298
+ cell_dict = cell.__dict__
1299
+ for key in cell_dict:
1300
+ if isinstance(cell_dict[key], ParameterTuple):
1301
+ param_tuple = cell_dict[key]
1302
+ new_param_tuple = []
1303
+ for param in param_tuple:
1304
+ if not auto_parallel_mode:
1305
+ new_param_tuple.append(_updata(param))
1306
+ continue
1307
+ if param.name in self.parallel_parameter_name_list:
1308
+ new_param_tuple.append(_updata(param))
1309
+ else:
1310
+ new_param_tuple.append(param)
1311
+ cell.__dict__[key] = ParameterTuple(new_param_tuple)
1312
+ return replace
1313
+
1314
+ def parameters_dict(self, recurse=True):
1315
+ """
1316
+ Gets the parameters dictionary of this cell.
1317
+
1318
+ Args:
1319
+ recurse (bool): Whether contains the parameters of subcells. Default: ``True`` .
1320
+
1321
+ Returns:
1322
+ OrderedDict, return parameters dictionary.
1323
+
1324
+ Examples:
1325
+ >>> import mindspore as ms
1326
+ >>> from mindspore import Tensor, nn, Parameter
1327
+ ...
1328
+ >>> class Net(nn.Cell):
1329
+ ... def __init__(self):
1330
+ ... super(Net, self).__init__()
1331
+ ... self.dense = nn.Dense(2, 2)
1332
+ ...
1333
+ ... def construct(self, x):
1334
+ ... x = self.dense(x)
1335
+ ... return x
1336
+ >>> net = Net()
1337
+ >>> print(net.parameters_dict())
1338
+ OrderedDict([('dense.weight', Parameter(name=dense.weight, shape=(2, 2), dtype=Float32,
1339
+ requires_grad=True)), ('dense.bias', Parameter(name=dense.bias, shape=(2,), dtype=Float32,
1340
+ requires_grad=True))])
1341
+ """
1342
+ param_dict = OrderedDict()
1343
+ for param in self.get_parameters(expand=recurse):
1344
+ param_dict[param.name] = param
1345
+ return param_dict
1346
+
1347
+ def parameters_broadcast_dict(self, recurse=True):
1348
+ """
1349
+ Gets the parameters broadcast dictionary of this cell.
1350
+
1351
+ Args:
1352
+ recurse (bool): Whether contains the parameters of subcells. Default: ``True`` .
1353
+
1354
+ Returns:
1355
+ OrderedDict, return parameters broadcast dictionary.
1356
+ """
1357
+ param_dict = OrderedDict()
1358
+ for param in self.get_parameters(expand=recurse):
1359
+ if param.layerwise_parallel is False:
1360
+ param_dict[param.name] = param
1361
+ if not param_dict:
1362
+ return None
1363
+ return param_dict
1364
+
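+ # Illustrative sketch (added note, not from the original source): only
+ # parameters whose layerwise_parallel flag is False are collected, so a plain
+ # Dense layer contributes both of its parameters to the broadcast dictionary.
+ # >>> from mindspore import nn
+ # >>> net = nn.Dense(2, 2)
+ # >>> broadcast = net.parameters_broadcast_dict()
+ # >>> sorted(broadcast.keys())
+ # ['bias', 'weight']
+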
1365
+ def update_parameters_name(self, prefix='', recurse=True):
1366
+ """
1367
+ Adds the `prefix` string to the names of parameters.
1368
+
1369
+ Args:
1370
+ prefix (str): The prefix string. Default: ``''`` .
1371
+ recurse (bool): Whether contains the parameters of subcells. Default: ``True`` .
1372
+ """
1373
+
1374
+ Validator.check_str_and_none_by_regular(prefix)
1375
+ for name, param in self.parameters_and_names(expand=recurse):
1376
+ if prefix != '':
1377
+ param.is_init = False
1378
+ param.name = prefix + name
1379
+
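+ # Illustrative sketch (added note, not from the original source): adding a
+ # prefix renames every parameter yielded by parameters_and_names, which is how
+ # parent cells namespace the parameters of their children.
+ # >>> from mindspore import nn
+ # >>> net = nn.Dense(2, 2)
+ # >>> net.update_parameters_name(prefix='backbone.')
+ # >>> [p.name for p in net.get_parameters()]
+ # ['backbone.weight', 'backbone.bias']
+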
1380
+ def _update_local_parameters_name(self, prefix='', recurse=True):
1381
+ """
1382
+ Updates the names of local parameters with given prefix string.
1383
+
1384
+ Adds the given prefix to the names of local parameters.
1385
+
1386
+ Local parameters are the parameters that are not provided by user input.
1387
+
1388
+ Args:
1389
+ prefix (str): The prefix string. Default: ''.
1390
+ recurse (bool): Whether contains the parameters of subcells. Default: ``True``.
1391
+ """
1392
+
1393
+ Validator.check_str_by_regular(prefix)
1394
+ for name, param in self.parameters_and_names(expand=recurse):
1395
+ if name in self._user_parameters:
1396
+ continue
1397
+ if prefix != '':
1398
+ param.is_init = False
1399
+ param.name = prefix + name
1400
+
1401
+ @jit_forbidden_register
1402
+ def trainable_params(self, recurse=True):
1403
+ """
1404
+ Returns all trainable parameters.
1405
+
1406
+ Returns a list of all trainable parameters.
1407
+
1408
+ Args:
1409
+ recurse (bool): Whether contains the trainable parameters of subcells. Default: ``True`` .
1410
+
1411
+ Returns:
1412
+ List, the list of trainable parameters.
1413
+
1414
+ Tutorial Examples:
1415
+ - `Model Training - Optimizer
1416
+ <https://mindspore.cn/tutorials/en/master/beginner/train.html#optimizer>`_
1417
+ """
1418
+ return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
1419
+
1420
+ @jit_forbidden_register
1421
+ def untrainable_params(self, recurse=True):
1422
+ """
1423
+ Returns all untrainable parameters.
1424
+
1425
+ Returns a list of all untrainable parameters.
1426
+
1427
+ Args:
1428
+ recurse (bool): Whether contains the untrainable parameters of subcells. Default: ``True`` .
1429
+
1430
+ Returns:
1431
+ List, the list of untrainable parameters.
1432
+ """
1433
+ return list(filter(lambda x: not x.requires_grad, self.get_parameters(expand=recurse)))
1434
+
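+ # Illustrative sketch (added note, not from the original source): the split
+ # between trainable_params and untrainable_params follows each parameter's
+ # requires_grad flag.
+ # >>> from mindspore import nn
+ # >>> net = nn.Dense(2, 2)
+ # >>> net.bias.requires_grad = False
+ # >>> [p.name for p in net.trainable_params()]
+ # ['weight']
+ # >>> [p.name for p in net.untrainable_params()]
+ # ['bias']
+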
1435
+ @jit_forbidden_register
1436
+ def get_parameters(self, expand=True):
1437
+ """
1438
+ Returns an iterator over cell parameters.
1439
+
1440
+ Yields parameters of this cell. If `expand` is ``true`` , yields parameters of this cell and all subcells.
1441
+ For more details about subcells, please see the example below.
1442
+
1443
+ Args:
1444
+ expand (bool): If ``true`` , yields parameters of this cell and all subcells. Otherwise, only yield
1445
+ parameters that are direct members of this cell. Default: ``True`` .
1446
+
1447
+ Returns:
1448
+ Iteration, all parameters at the cell.
1449
+
1450
+ Examples:
1451
+ >>> import mindspore as ms
1452
+ >>> from mindspore import nn, ops, Tensor
1453
+ >>> import numpy as np
1454
+ >>> class TestNet(nn.Cell):
1455
+ ... def __init__(self):
1456
+ ... super().__init__()
1457
+ ... self.my_w1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
1458
+ ... self.my_w2 = ms.Parameter(Tensor(np.ones([16]), ms.float32))
1459
+ ... def construct(self, x):
1460
+ ... x += self.my_w1
1461
+ ... x = ops.reshape(x, (16,)) - self.my_w2
1462
+ ... return x
1463
+ >>> class TestNet2(nn.Cell):
1464
+ ... def __init__(self):
1465
+ ... super().__init__()
1466
+ ... self.my_t1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
1467
+ ... # self.subcell is a subcell of TestNet2, when using expand=True, the parameters of TestNet will
1468
+ ... # also be gathered.
1469
+ ... self.subcell = TestNet()
1470
+ ... def construct(self, x):
1471
+ ... x += self.my_t1
1472
+ ... x = self.subcell(x)
1473
+ ... return x
1474
+ >>> net = TestNet2()
1475
+ >>> print([p for p in net.get_parameters(expand=True)])
1476
+ [Parameter (name=my_t1, shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w1,
1477
+ shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w2, shape=(16,), dtype=Float32,
1478
+ requires_grad=True)]
1479
+ """
1480
+ for _, param in self.parameters_and_names(expand=expand):
1481
+ yield param
1482
+
1483
+ # pylint: disable=missing-docstring
1484
+ def check_names_and_refresh_name(self):
1485
+ if not hasattr(self, "_params"):
1486
+ return
1487
+ all_name = [i.name for i in dict(self.parameters_and_names()).values()]
1488
+ if len(set(all_name)) < len(all_name):
1489
+ self.update_parameters_name()
1490
+ self.check_names()
1491
+
1492
+ def check_names(self):
1493
+ """
1494
+ Check the names of cell parameters.
1495
+ """
1496
+ names = set("")
1497
+ for value, param in self.parameters_and_names():
1498
+ if param.name in names:
1499
+ raise ValueError("The value of {} is {}, its name '{}' already exists. "
1500
+ "Please set a unique name for the parameter.".format(value, param, param.name))
1501
+ names.add(param.name)
1502
+
1503
+ def parameters_and_names(self, name_prefix='', expand=True):
1504
+ """
1505
+ Returns an iterator over cell parameters.
1506
+
1507
+ Includes the parameter's name and itself.
1508
+
1509
+ Args:
1510
+ name_prefix (str): Namespace. Default: ``''`` .
1511
+ expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters
1512
+ that are direct members of this cell. Default: ``True`` .
1513
+
1514
+ Returns:
1515
+ Iteration, all the names and corresponding parameters in the cell.
1516
+
1517
+ Examples:
1518
+ >>> from mindspore import nn
1519
+ >>> n = nn.Dense(3, 4)
1520
+ >>> names = []
1521
+ >>> for m in n.parameters_and_names():
1522
+ ... if m[0]:
1523
+ ... names.append(m[0])
1524
+
1525
+ Tutorial Examples:
1526
+ - `Building a Network - Model Parameters
1527
+ <https://mindspore.cn/tutorials/en/master/beginner/model.html#model-parameters>`_
1528
+ """
1529
+ cells = []
1530
+ if expand:
1531
+ cells = self.cells_and_names(name_prefix=name_prefix)
1532
+ else:
1533
+ cells.append((name_prefix, self))
1534
+
1535
+ params_set = set()
1536
+ for cell_name, cell in cells:
1537
+ params = cell._params.items()
1538
+ for par_name, par in params:
1539
+ if par is not None and par.inited_param is not None:
1540
+ par = par.inited_param
1541
+ if par is not None and id(par) not in params_set:
1542
+ params_set.add(id(par))
1543
+ par_new_name = par_name
1544
+ if cell_name:
1545
+ par_new_name = cell_name + '.' + par_new_name
1546
+
1547
+ yield par_new_name, par
1548
+
1549
+ def cells_and_names(self, cells=None, name_prefix=''):
1550
+ """
1551
+ Returns an iterator over all cells in the network, including the cell's name and itself.
1552
+
1553
+ Args:
1554
+ cells (set): The set of cells that have already been iterated over, used to avoid duplicates. Default: ``None`` .
1555
+ name_prefix (str): Namespace. Default: ``''`` .
1556
+
1557
+ Returns:
1558
+ Iteration, all the child cells and corresponding names in the cell.
1559
+
1560
+ Examples:
1561
+ >>> from mindspore import nn
1562
+ >>> class Net(nn.Cell):
1563
+ ... def __init__(self):
1564
+ ... super(Net, self).__init__()
1565
+ ... self.conv = nn.Conv2d(3, 64, 3)
1566
+ ... def construct(self, x):
1567
+ ... out = self.conv(x)
1568
+ ... return out
1569
+ >>> names = []
1570
+ >>> n = Net()
1571
+ >>> for m in n.cells_and_names():
1572
+ ... if m[0]:
1573
+ ... names.append(m[0])
1574
+ """
1575
+ t_cells = cells if cells else set()
1576
+ if self in t_cells:
1577
+ return
1578
+
1579
+ t_cells.add(self)
1580
+ yield name_prefix, self
1581
+
1582
+ for name, cell in self._cells.items():
1583
+ if cell:
1584
+ cells_name_prefix = name
1585
+ if name_prefix:
1586
+ cells_name_prefix = name_prefix + '.' + cells_name_prefix
1587
+ for ele in cell.cells_and_names(t_cells, cells_name_prefix):
1588
+ yield ele
1589
+
1590
+ def cells(self):
1591
+ """
1592
+ Returns an iterator over immediate cells.
1593
+
1594
+ Returns:
1595
+ Iteration, the immediate cells in the cell.
1596
+
1597
+ Examples:
1598
+ >>> import mindspore as ms
1599
+ >>> from mindspore import Tensor, nn
1600
+ ...
1601
+ >>> class Net(nn.Cell):
1602
+ ... def __init__(self):
1603
+ ... super(Net, self).__init__()
1604
+ ... self.dense = nn.Dense(2, 2)
1605
+ ...
1606
+ ... def construct(self, x):
1607
+ ... x = self.dense(x)
1608
+ ... return x
1609
+ >>> net = Net()
1610
+ >>> print(net.cells())
1611
+ odict_values([Dense<input_channels=2, output_channels=2, has_bias=True>])
1612
+ """
1613
+ return self.name_cells().values()
1614
+
1615
+ def _set_scope(self, name):
1616
+ """Sets the name on the first time."""
1617
+ if self._scope is None:
1618
+ self._scope = name
1619
+ elif self._scope == 'recompute_':
1620
+ self._scope = self._scope + name
1621
+
1622
+ def _children_scope_recursive(self, parent_prefix='Default'):
1623
+ """Generates the scope of each layer of the network recursively."""
1624
+ reserve_class_name_in_scope = context.get_context("reserve_class_name_in_scope")
1625
+
1626
+ for name, cell in self.name_cells().items():
1627
+ class_name = ("-" + cell.__class__.__name__) if reserve_class_name_in_scope else ""
1628
+ yield parent_prefix + "/" + name + class_name, cell
1629
+
1630
+ for name, cell in self.name_cells().items():
1631
+ class_name = ("-" + cell.__class__.__name__) if reserve_class_name_in_scope else ""
1632
+ for key, value in cell._children_scope_recursive(parent_prefix + "/" + name + class_name):
1633
+ yield key, value
1634
+
1635
+ def get_scope(self):
1636
+ """
1637
+ Returns the scope of a cell object in one network.
1638
+
1639
+ Returns:
1640
+ String, scope of the cell.
1641
+ """
1642
+ return self._scope
1643
+
1644
+ def generate_scope(self):
1645
+ """Generate the scope for each cell object in the network."""
1646
+ for name, cell in self._children_scope_recursive():
1647
+ cell._set_scope(name)
1648
+
1649
+ def name_cells(self):
1650
+ """
1651
+ Returns an iterator over all immediate cells in the network.
1652
+
1653
+ Includes the name of the cell and the cell itself.
1654
+
1655
+ Returns:
1656
+ Dict, all the child cells and corresponding names in the cell.
1657
+
1658
+ Examples:
1659
+ >>> import mindspore as ms
1660
+ >>> from mindspore import Tensor, nn
1661
+ ...
1662
+ >>> class Net(nn.Cell):
1663
+ ... def __init__(self):
1664
+ ... super(Net, self).__init__()
1665
+ ... self.dense = nn.Dense(2, 2)
1666
+ ...
1667
+ ... def construct(self, x):
1668
+ ... x = self.dense(x)
1669
+ ... return x
1670
+ >>> net = Net()
1671
+ >>> print(net.name_cells())
1672
+ OrderedDict([('dense', Dense<input_channels=2, output_channels=2, has_bias=True>)])
1673
+ """
1674
+ value_set = set()
1675
+ cells = OrderedDict()
1676
+ for name, cell in self._cells.items():
1677
+ if cell is not None and cell not in value_set:
1678
+ value_set.add(cell)
1679
+ cells[name] = cell
1680
+ return cells
1681
+
1682
+ def _add_mixed_precision_flag(self, **flags):
1683
+ """Add mixed precision flag to current cell"""
1684
+ if "fp16" in flags and flags.get("fp16", False):
1685
+ Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
1686
+ if "fp32" in flags and flags.get("fp32", False):
1687
+ Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
1688
+ if "bf16" in flags and flags.get("bf16", False):
1689
+ Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16)
1690
+
1691
+ def apply(self, fn):
1692
+ """
1693
+ Applies fn recursively to every subcell (as returned by .cells()) as well as self.
1694
+ Typical use includes initializing the parameters of a model.
1695
+
1696
+ Args:
1697
+ fn (function): function to be applied to each subcell.
1698
+
1699
+ Returns:
1700
+ Cell, self.
1701
+
1702
+ Examples:
1703
+ >>> import mindspore.nn as nn
1704
+ >>> from mindspore.common.initializer import initializer, One
1705
+ >>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
1706
+ >>> def func(cell):
1707
+ ... if isinstance(cell, nn.Dense):
1708
+ ... cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
1709
+ >>> net.apply(func)
1710
+ SequentialCell<
1711
+ (0): Dense<input_channels=2, output_channels=2, has_bias=True>
1712
+ (1): Dense<input_channels=2, output_channels=2, has_bias=True>
1713
+ >
1714
+ >>> print(net[0].weight.asnumpy())
1715
+ [[1. 1.]
1716
+ [1. 1.]]
1717
+ """
1718
+ for cell in self.cells():
1719
+ cell.apply(fn)
1720
+ fn(self)
1721
+ return self
1722
+
1723
+ def add_flags(self, **flags):
1724
+ """
1725
+ Add customized attributes for cell.
1726
+
1727
+ This method is also called when the cell class is instantiated and the class parameter 'flags' is set to True.
1728
+
1729
+ Args:
1730
+ flags (dict): Network configuration information, currently it is used for the binding of network and
1731
+ dataset. Users can also customize network attributes by this parameter.
1732
+
1733
+ Examples:
1734
+ >>> import mindspore as ms
1735
+ >>> from mindspore import Tensor, nn
1736
+ ...
1737
+ >>> class Net(nn.Cell):
1738
+ ... def __init__(self):
1739
+ ... super(Net, self).__init__()
1740
+ ... self.relu = nn.ReLU()
1741
+ ...
1742
+ ... def construct(self, x):
1743
+ ... x = self.relu(x)
1744
+ ... return x
1745
+ >>> net = Net()
1746
+ >>> net.add_flags(sink_mode=True)
1747
+ >>> print(net.sink_mode)
1748
+ True
1749
+ """
1750
+ if not hasattr(self, "_func_graph_flags"):
1751
+ self._func_graph_flags = {}
1752
+ self._func_graph_flags.update({**flags})
1753
+ if context._get_mode() == context.PYNATIVE_MODE and self._func_graph_flags.get("output_no_recompute"):
1754
+ raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
1755
+ "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
1756
+ self.__dict__.update({**flags})
1757
+ self._add_mixed_precision_flag(**flags)
1758
+ return self
1759
+
1760
+ def add_flags_recursive(self, **flags):
1761
+ """
1762
+ If a cell contains child cells, this method can recursively customize attributes of all cells.
1763
+
1764
+ Args:
1765
+ flags (dict): Network configuration information, currently it is used for the binding of network and
1766
+ dataset. Users can also customize network attributes by this parameter.
1767
+
1768
+ Examples:
1769
+ >>> import mindspore as ms
1770
+ >>> from mindspore import Tensor, nn
1771
+ ...
1772
+ >>> class Net(nn.Cell):
1773
+ ... def __init__(self):
1774
+ ... super(Net, self).__init__()
1775
+ ... self.relu = nn.ReLU()
1776
+ ...
1777
+ ... def construct(self, x):
1778
+ ... x = self.relu(x)
1779
+ ... return x
1780
+ >>> net = Net()
1781
+ >>> net.add_flags_recursive(sink_mode=True)
1782
+ >>> print(net.sink_mode)
1783
+ True
1784
+ """
1785
+ self.add_flags(**flags)
1786
+ for cell in self.cells():
1787
+ cell.add_flags_recursive(**flags)
1788
+ return self
1789
+
1790
+ def _add_init_args(self, **args):
1791
+ if hasattr(self, '_cell_init_args'):
1792
+ self._cell_init_args += str({**args})
1793
+
1794
+ def get_flags(self):
1795
+ """
1796
+ Get the self_defined attributes of the cell, which can be added by `add_flags` method.
1797
+
1798
+ Examples:
1799
+ >>> import mindspore as ms
1800
+ >>> from mindspore import Tensor, nn
1801
+ ...
1802
+ >>> class Net(nn.Cell):
1803
+ ... def __init__(self):
1804
+ ... super(Net, self).__init__()
1805
+ ... self.relu = nn.ReLU()
1806
+ ...
1807
+ ... def construct(self, x):
1808
+ ... x = self.relu(x)
1809
+ ... return x
1810
+ >>> net = Net()
1811
+ >>> net.add_flags(sink_mode=True)
1812
+ >>> print(net.get_flags())
1813
+ {'sink_mode':True}
1814
+ """
1815
+ if not hasattr(self, "_func_graph_flags"):
1816
+ self._func_graph_flags = {}
1817
+ return self._func_graph_flags
1818
+
1819
+ def to_float(self, dst_type):
1820
+ """
1821
+ Add cast on all inputs of cell and child cells to run with certain float type.
1822
+
1823
+ If `dst_type` is `mindspore.dtype.float16`, all the inputs of Cell, including input, Parameter and Tensor, will
1824
+ be cast to float16. Please refer to the usage in source code of :func:`mindspore.amp.build_train_network`.
1825
+
1826
+ Note:
1827
+ Multiple calls will overwrite the previous setting.
1828
+
1829
+ Args:
1830
+ dst_type (:class:`mindspore.dtype`): Transfer cell to run with dst_type.
1831
+ dst_type can be `mstype.float16` , `mstype.float32` or `mstype.bfloat16`.
1832
+
1833
+ Returns:
1834
+ Cell, the cell itself.
1835
+
1836
+ Raises:
1837
+ ValueError: If dst_type is not `mstype.float32` , `mstype.float16` or `mstype.bfloat16`.
1838
+
1839
+ Supported Platforms:
1840
+ ``Ascend`` ``GPU`` ``CPU``
1841
+
1842
+ Examples:
1843
+ >>> import mindspore.nn as nn
1844
+ >>> from mindspore import dtype as mstype
1845
+ >>>
1846
+ >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
1847
+ >>> net.to_float(mstype.float16)
1848
+ Conv2d<input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
1849
+ padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW>
1850
+ """
1851
+ if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16):
1852
+ raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or "
1853
+ "mstype.bfloat16, but got type: {} and value: {}.".format(type(dst_type), dst_type))
1854
+ flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32,
1855
+ 'bf16': dst_type == mstype.bfloat16}
1856
+ self._add_init_args(**flags)
1857
+ self.add_flags_recursive(**flags)
1858
+ return self
1859
+
1860
+ def set_boost(self, boost_type):
1861
+ """
1862
+ In order to improve network performance, configure the network to automatically enable
1863
+ acceleration algorithms from the algorithm library.
1864
+
1865
+ If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
1866
+ `algorithm library <https://gitee.com/mindspore/mindspore/tree/master/mindspore/python/mindspore/boost>`_.
1867
+
1868
+ Note:
1869
+ Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
1870
+
1871
+ Args:
1872
+ boost_type (str): The acceleration algorithm to enable.
1873
+
1874
+ Returns:
1875
+ Cell, the cell itself.
1876
+
1877
+ Raises:
1878
+ ValueError: If boost_type is not in the algorithm library.
1879
+ """
1880
+ if boost_type not in ("less_bn",):
1881
+ raise ValueError("For 'set_boost', the argument 'boost_type' must be 'less_bn', "
1882
+ "but got {}.".format(boost_type))
1883
+ flags = {"less_bn": boost_type == "less_bn"}
1884
+ self.add_flags_recursive(**flags)
1885
+ return self
1886
+
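+ # Illustrative sketch (added note, not from the original source): enabling the
+ # "less_bn" acceleration algorithm on a network; whether it helps accuracy or
+ # speed depends on the model, as cautioned above.
+ # >>> from mindspore import nn
+ # >>> net = nn.Dense(2, 2)
+ # >>> net = net.set_boost("less_bn")
+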
1887
+ def set_grad(self, requires_grad=True):
1888
+ """
1889
+ Sets the cell flag for gradient. In pynative mode, this parameter specifies whether the network requires
1890
+ gradients. If ``true`` , the backward network needed to compute the gradients will be generated when the forward
1891
+ network is executed.
1892
+
1893
+ Args:
1894
+ requires_grad (bool): Specifies whether the net needs to compute gradients. If it is
1895
+ ``true`` , the cell will construct the backward network in pynative mode. Default: ``True`` .
1896
+
1897
+ Returns:
1898
+ Cell, the cell itself.
1899
+ """
1900
+ self.requires_grad = requires_grad
1901
+ return self
1902
+
1903
+ def set_train(self, mode=True):
1904
+ """
1905
+ Sets the cell to training mode.
1906
+
1907
+ The cell itself and all children cells will be set to training mode. Layers that have different constructions
1908
+ for training and predicting, such as `BatchNorm`, will distinguish between the branches by this attribute. If
1909
+ set to true, the training branch will be executed, otherwise the other branch will be executed.
1910
+
1911
+ Note:
1912
+ When executing the function Model.train(), the framework will call Cell.set_train(True).
1913
+ When executing the function Model.eval(), the framework will call Cell.set_train(False).
1914
+
1915
+ Args:
1916
+ mode (bool): Specifies whether the model is training. Default: ``True`` .
1917
+
1918
+ Returns:
1919
+ Cell, the cell itself.
1920
+
1921
+ Tutorial Examples:
1922
+ - `Model Training - Implementing Training and Evaluation
1923
+ <https://mindspore.cn/tutorials/en/master/beginner/train.html#training-and-evaluation>`_
1924
+ """
1925
+ if mode:
1926
+ self._phase = 'train'
1927
+ else:
1928
+ self._phase = 'predict'
1929
+ self.add_flags_recursive(training=mode)
1930
+ return self
1931
+
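+ # Illustrative sketch (added note, not from the original source): layers such
+ # as Dropout behave differently in the two phases, so evaluation code switches
+ # the flag off before inference.
+ # >>> from mindspore import nn
+ # >>> net = nn.Dropout(p=0.5)
+ # >>> net = net.set_train(False)   # use the inference branch (identity)
+ # >>> net.training
+ # False
+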
1932
+ def set_broadcast_flag(self, mode=True):
1933
+ """
1934
+ Set parameter broadcast mode for this cell.
1935
+
1936
+ Args:
1937
+ mode (bool): Specifies whether the mode is parameter broadcast. Default: ``True`` .
1938
+ """
1939
+ self.add_flags_recursive(broadcast_flag=mode)
1940
+ return self
1941
+
1942
+ def set_auto_parallel(self):
1943
+ """
1944
+ Set the cell to auto parallel mode.
1945
+
1946
+ Note:
1947
+ This interface is deprecated.
1948
+ """
1949
+ logger.warning("'set_auto_parallel' function is deprecated.")
1950
+
1951
+ def set_jit_config(self, jit_config):
1952
+ """
1953
+ Set jit config for cell.
1954
+
1955
+ Args:
1956
+ jit_config (JitConfig): Jit config for compile. For details, please refer to :class:`mindspore.JitConfig`.
1957
+
1958
+ Examples:
1959
+ >>> import mindspore as ms
1960
+ >>> from mindspore import Tensor, nn
1961
+ ...
1962
+ >>> class Net(nn.Cell):
1963
+ ... def __init__(self):
1964
+ ... super(Net, self).__init__()
1965
+ ... self.relu = nn.ReLU()
1966
+ ...
1967
+ ... def construct(self, x):
1968
+ ... x = self.relu(x)
1969
+ ... return x
1970
+ >>> net = Net()
1971
+ >>> jitconfig = ms.JitConfig()
1972
+ >>> net.set_jit_config(jitconfig)
1973
+ """
1974
+ if self._jit_config_dict:
1975
+ logger.warning("For Cell, jit config can only be set once, ignore this setting.")
1976
+ else:
1977
+ self._jit_config_dict = jit_config.jit_config_dict
1978
+
1979
+ def flatten_weights(self, fusion_size=0):
1980
+ """
1981
+ Reset data for weight parameters so that they are using contiguous memory chunks grouped by data type.
1982
+
1983
+ Note:
1984
+ By default, parameters with the same data type will use a single contiguous memory chunk. But for
1985
+ some models with a huge number of parameters, splitting a large memory chunk into several smaller
1986
+ memory chunks has the potential for performance gains. If this is the case, we can use 'fusion_size'
1987
+ to limit the maximum memory chunk size.
1988
+
1989
+ Args:
1990
+ fusion_size (int): Maximum memory chunk size in bytes, ``0`` for unlimited. Default: ``0`` .
1991
+ """
1992
+ if fusion_size < 0:
1993
+ raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.")
1994
+ Tensor._flatten_tensors(self.trainable_params(), fusion_size) # pylint: disable=W0212
1995
+
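+ # Illustrative sketch (added note, not from the original source): flattening
+ # trainable weights into contiguous chunks, optionally capping each chunk; the
+ # 64 MB value below is an arbitrary example, not a recommendation.
+ # >>> from mindspore import nn
+ # >>> net = nn.Dense(2, 2)
+ # >>> net.flatten_weights(fusion_size=64 * 1024 * 1024)
+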
1996
+ def register_forward_pre_hook(self, hook_fn):
1997
+ """
1998
+ Register forward pre hook function for Cell object.
1999
+
2000
+ Note:
2001
+ - The `register_forward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
2002
+ - 'hook_fn' must be defined as the following code.
2003
+ `cell` is the object of registered Cell. `inputs` is the forward
2004
+ input objects passed to the Cell. The 'hook_fn' can modify the forward input objects by returning new
2005
+ forward input objects.
2006
+ - It should have the following signature:
2007
+ hook_fn(cell, inputs) -> new input objects or none.
2008
+ - In order to prevent failures when switching to graph mode, it is not recommended to write it in the
2009
+ `construct` function of Cell object. In the pynative mode, if the `register_forward_pre_hook` function is
2010
+ called in the `construct` function of the Cell object, a hook function will be added at each run time of
2011
+ Cell object.
2012
+
2013
+ Args:
2014
+ hook_fn (function): Python function. Forward pre hook function.
2015
+
2016
+ Returns:
2017
+ A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
2018
+ `handle.remove()` .
2019
+
2020
+ Raises:
2021
+ TypeError: If the `hook_fn` is not a function of python.
2022
+
2023
+ Supported Platforms:
2024
+ ``Ascend`` ``GPU`` ``CPU``
2025
+
2026
+ Examples:
2027
+ >>> import numpy as np
2028
+ >>> import mindspore as ms
2029
+ >>> from mindspore import Tensor, nn, ops
2030
+ >>> ms.set_context(mode=ms.PYNATIVE_MODE)
2031
+ >>> def forward_pre_hook_fn(cell, inputs):
2032
+ ... print("forward inputs: ", inputs)
2033
+ ...
2034
+ >>> class Net(nn.Cell):
2035
+ ... def __init__(self):
2036
+ ... super(Net, self).__init__()
2037
+ ... self.mul = nn.MatMul()
2038
+ ... self.handle = self.mul.register_forward_pre_hook(forward_pre_hook_fn)
2039
+ ...
2040
+ ... def construct(self, x, y):
2041
+ ... x = x + x
2042
+ ... x = self.mul(x, y)
2043
+ ... return x
2044
+ >>> grad = ops.GradOperation(get_all=True)
2045
+ >>> net = Net()
2046
+ >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
2047
+ forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
2048
+ dtype=Float32, value= [ 1.00000000e+00]))
2049
+ >>> print(output)
2050
+ (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
2051
+ value= [ 2.00000000e+00]))
2052
+ """
2053
+ if not check_hook_fn("register_forward_pre_hook", hook_fn):
2054
+ return HookHandle()
2055
+ self._enable_forward_pre_hook = True
2056
+ _pynative_executor.set_hook_changed(self)
2057
+ if not hasattr(self, '_forward_pre_hook_key'):
2058
+ self._forward_pre_hook_key = -1
2059
+ self._forward_pre_hook_key += 1
2060
+ self._forward_pre_hook[self._forward_pre_hook_key] = hook_fn
2061
+ handle = HookHandle(self, self._forward_pre_hook_key, "_forward_pre_hook")
2062
+ return handle
2063
+
2064
+ def _run_forward_pre_hook(self, inputs):
2065
+ """
2066
+ Running forward pre hook function registered on Cell object.
2067
+
2068
+ Args:
2069
+ inputs: The input objects of cell object.
2070
+
2071
+ Returns:
2072
+ - **outputs** - New input objects or none.
2073
+
2074
+ Supported Platforms:
2075
+ ``Ascend`` ``GPU`` ``CPU``
2076
+ """
2077
+ for fn in self._forward_pre_hook.values():
2078
+ ret = fn(self, inputs)
2079
+ if ret is not None:
2080
+ if not isinstance(ret, tuple):
2081
+ inputs = (ret,)
2082
+ else:
2083
+ inputs = ret
2084
+ return inputs
2085
+
2086
+ def register_forward_hook(self, hook_fn):
2087
+ """
2088
+ Set the Cell forward hook function.
2089
+
2090
+ Note:
2091
+ - The `register_forward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
2092
+ - 'hook_fn' must be defined as the following code.
2093
+ `cell` is the object of registered Cell. `inputs` is the forward
2094
+ input objects passed to the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can
2095
+ modify the forward output object by returning new forward output object.
2096
+ - It should have the following signature:
2097
+ hook_fn(cell, inputs, output) -> new output object or none.
2098
+ - In order to prevent failures when switching to graph mode, it is not recommended to write it in the
2099
+ `construct` function of Cell object. In the pynative mode, if the `register_forward_hook` function is
2100
+ called in the `construct` function of the Cell object, a hook function will be added at each run time of
2101
+ Cell object.
2102
+
2103
+ Args:
2104
+ hook_fn (function): Python function. Forward hook function.
2105
+
2106
+ Returns:
2107
+ A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
2108
+ `handle.remove()` .
2109
+
2110
+ Raises:
2111
+ TypeError: If the `hook_fn` is not a function of python.
2112
+
2113
+ Supported Platforms:
2114
+ ``Ascend`` ``GPU`` ``CPU``
2115
+
2116
+ Examples:
2117
+ >>> import numpy as np
2118
+ >>> import mindspore as ms
2119
+ >>> from mindspore import Tensor, nn, ops
2120
+ >>> ms.set_context(mode=ms.PYNATIVE_MODE)
2121
+ >>> def forward_hook_fn(cell, inputs, output):
2122
+ ... print("forward inputs: ", inputs)
2123
+ ... print("forward output: ", output)
2124
+ ...
2125
+ >>> class Net(nn.Cell):
2126
+ ... def __init__(self):
2127
+ ... super(Net, self).__init__()
2128
+ ... self.mul = nn.MatMul()
2129
+ ... self.handle = self.mul.register_forward_hook(forward_hook_fn)
2130
+ ...
2131
+ ... def construct(self, x, y):
2132
+ ... x = x + x
2133
+ ... x = self.mul(x, y)
2134
+ ... return x
2135
+ >>> grad = ops.GradOperation(get_all=True)
2136
+ >>> net = Net()
2137
+ >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
2138
+ forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
2139
+ dtype=Float32, value= [ 1.00000000e+00]))
2140
+ forward output: 2.0
2141
+ >>> print(output)
2142
+ (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
2143
+ value= [ 2.00000000e+00]))
2144
+ """
2145
+ if not check_hook_fn("register_forward_hook", hook_fn):
2146
+ return HookHandle()
2147
+ self._enable_forward_hook = True
2148
+ _pynative_executor.set_hook_changed(self)
2149
+ if not hasattr(self, '_forward_hook_key'):
2150
+ self._forward_hook_key = -1
2151
+ self._forward_hook_key += 1
2152
+ self._forward_hook[self._forward_hook_key] = hook_fn
2153
+ handle = HookHandle(self, self._forward_hook_key, "_forward_hook")
2154
+ return handle
2155
+
2156
+ def _run_forward_hook(self, inputs, output):
2157
+ """
2158
+ Running forward hook function registered on Cell object.
2159
+
2160
+ Args:
2161
+ inputs: The input objects of Cell object.
2162
+ output: The output object of Cell object.
2163
+
2164
+ Returns:
2165
+ - **output** - New output object or none.
2166
+
2167
+ Supported Platforms:
2168
+ ``Ascend`` ``GPU`` ``CPU``
2169
+ """
2170
+ for fn in self._forward_hook.values():
2171
+ ret = fn(self, inputs, output)
2172
+ if ret is not None:
2173
+ output = ret
2174
+ return output
2175
+
2176
+ def register_backward_hook(self, hook_fn):
2177
+ """
2178
+ Register the backward hook function.
2179
+
2180
+ Note:
2181
+ - The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
2182
+ - The 'hook_fn' must be defined as the following code.
2183
+ `cell_id` is the information of registered Cell object, including name and ID. `grad_input` is the
2184
+ gradient passed to the Cell. `grad_output` is the gradient computed and passed to the next Cell or
2185
+ primitive, which may be modified by returning a new output gradient.
2186
+ - The 'hook_fn' should have the following signature:
2187
+ hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
2188
+ - The 'hook_fn' is executed in the python environment. In order to prevent failures when switching to
2189
+ graph mode, it is not recommended to write it in the `construct` function of Cell object. In the pynative
2190
+ mode, if the `register_backward_hook` function is called in the `construct` function of the Cell object,
2191
+ a hook function will be added at each run time of Cell object.
2192
+
2193
+ Args:
2194
+ hook_fn (function): Python function. Backward hook function.
2195
+
2196
+ Returns:
2197
+ A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
2198
+ `handle.remove()` .
2199
+
2200
+ Raises:
2201
+ TypeError: If the `hook_fn` is not a function of python.
2202
+
2203
+ Supported Platforms:
2204
+ ``Ascend`` ``GPU`` ``CPU``
2205
+
2206
+ Examples:
2207
+ >>> import numpy as np
2208
+ >>> import mindspore as ms
2209
+ >>> from mindspore import Tensor, nn, ops
2210
+ >>> ms.set_context(mode=ms.PYNATIVE_MODE)
2211
+ >>> def backward_hook_fn(cell_id, grad_input, grad_output):
2212
+ ... print("backward input: ", grad_input)
2213
+ ... print("backward output: ", grad_output)
2214
+ ...
2215
+ >>> class Net(nn.Cell):
2216
+ ... def __init__(self):
2217
+ ... super(Net, self).__init__()
2218
+ ... self.relu = nn.ReLU()
2219
+ ... self.handle = self.relu.register_backward_hook(backward_hook_fn)
2220
+ ...
2221
+ ... def construct(self, x):
2222
+ ... x = x + x
2223
+ ... x = self.relu(x)
2224
+ ... return x
2225
+ >>> grad = ops.GradOperation(get_all=True)
2226
+ >>> net = Net()
2227
+ >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)))
2228
+ backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
2229
+ backward output: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
2230
+ >>> print(output)
2231
+ (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
2232
+ """
2233
+ if not check_hook_fn("register_backward_hook", hook_fn):
2234
+ return HookHandle()
2235
+ if self._cell_backward_hook is None:
2236
+ self._enable_backward_hook = True
2237
+ self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
2238
+ backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
2239
+ handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
2240
+ else:
2241
+ backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
2242
+ handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
2243
+ return handle
2244
+
2245
+ def _backward_hook_construct(self, *inputs, **kwargs):
2246
+ """
2247
+ Backward hook construct method to replace original construct method.
2248
+
2249
+ Args:
2250
+ inputs: The input objects of Cell object.
2251
+ kwargs (dict): Dictionary of variable keyword parameters.
2252
+
2253
+ Returns:
2254
+ - **outputs** - The output objects of Cell object.
2255
+
2256
+ Supported Platforms:
2257
+ ``Ascend`` ``GPU`` ``CPU``
2258
+ """
2259
+ if len(inputs) > 1:
2260
+ inputs = self._cell_backward_hook(inputs)
2261
+ else:
2262
+ inputs = self._cell_backward_hook(*inputs)
2263
+ inputs = (inputs,)
2264
+ if self.recompute_cell is not None:
2265
+ if isinstance(inputs, tuple):
2266
+ outputs = self.recompute_cell(*inputs, **kwargs)
2267
+ else:
2268
+ outputs = self.recompute_cell(inputs, **kwargs)
2269
+ else:
2270
+ if isinstance(inputs, tuple):
2271
+ outputs = self.construct(*inputs, **kwargs)
2272
+ else:
2273
+ outputs = self.construct(inputs, **kwargs)
2274
+ outputs = self._cell_backward_hook(outputs)
2275
+ return outputs
2276
+
2277
+ def set_param_ps(self, recurse=True, init_in_server=False):
2278
+ """
2279
+ Set whether the trainable parameters are updated by parameter server and whether the
2280
+ trainable parameters are initialized on server.
2281
+
2282
+ Note:
2283
+ It only works when a running task is in the parameter server mode.
2284
+ It is only supported in graph mode.
2285
+
2286
+ Args:
2287
+ recurse (bool): Whether sets the trainable parameters of subcells. Default: ``True`` .
2288
+ init_in_server (bool): Whether trainable parameters updated by parameter server are
2289
+ initialized on server. Default: ``False`` .
2290
+ """
2291
+ params = self.trainable_params(recurse)
2292
+ for param in params:
2293
+ param.set_param_ps(init_in_server)
2294
+
2295
+ @deprecated("1.8", "set_param_fl")
2296
+ def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
2297
+ params = self.parameters_and_names()
2298
+ for param in params:
2299
+ param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr)
2300
+
2301
+ def set_comm_fusion(self, fusion_type, recurse=True):
2302
+ """
2303
+ Set `comm_fusion` for all the parameters in this cell. Please refer to the description of
2304
+ :class:`mindspore.Parameter.comm_fusion`.
2305
+
2306
+ Note:
2307
+ The value of the attribute will be overwritten when the function is called multiple times.
2308
+
2309
+ Args:
2310
+ fusion_type (int): The value of `comm_fusion`.
2311
+ recurse (bool): Whether sets the trainable parameters of subcells. Default: ``True`` .
2312
+ """
2313
+ Validator.check_non_negative_int(fusion_type)
2314
+ for param in self.trainable_params(recurse):
2315
+ param.comm_fusion = fusion_type
2316
+ return self
2317
+
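+ # Illustrative sketch (added note, not from the original source): tagging all
+ # trainable parameters of a cell with the same comm_fusion group id (the value
+ # 2 below is arbitrary, chosen only for illustration).
+ # >>> from mindspore import nn
+ # >>> net = nn.Dense(2, 2)
+ # >>> _ = net.set_comm_fusion(2)
+ # >>> net.weight.comm_fusion
+ # 2
+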
2318
+ def _set_recompute_scope(self, mode):
2319
+ prefix = 'recompute_'
2320
+ if mode:
2321
+ if self._scope is None:
2322
+ self._scope = prefix
2323
+ elif not self._scope.startswith(prefix):
2324
+ self._scope = prefix + self._scope
2325
+ elif self._scope is not None and self._scope.startswith(prefix):
2326
+ self._scope = self._scope[len(prefix):]
2327
+
2328
+ def _mp_comm_recompute(self, mp_comm_recompute=True):
2329
+ """
2330
+ Set the model parallel communication operators in the cell to be recomputed.
2331
+ """
2332
+ for _, value in self._primitives.items():
2333
+ if value:
2334
+ value.add_prim_attr("recompute_comm_op", mp_comm_recompute)
2335
+ for cell in self.cells():
2336
+ cell._mp_comm_recompute(mp_comm_recompute)
2337
+
2338
+ def _parallel_optimizer_comm_recompute(self, parallel_optimizer_comm_recompute=False):
2339
+ """
2340
+         Set whether the parallel optimizer communication operators in the cell are recomputed.
2341
+ """
2342
+ for param in self.trainable_params():
2343
+ param.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute
2344
+
2345
+ def _recompute_slice_activation(self, slice_activation=False):
2346
+ """
2347
+         Slice the cell output which would remain in memory.
2348
+ """
2349
+ for _, value in self._primitives.items():
2350
+ if value:
2351
+ value.add_prim_attr("slice_activation", slice_activation)
2352
+ for cell in self.cells():
2353
+ cell._recompute_slice_activation(slice_activation)
2354
+
2355
+ def _recompute(self, mode=True, output_recompute=False):
2356
+ """
2357
+ Set the cell recomputed.
2358
+ """
2359
+ Validator.check_bool(mode)
2360
+ Validator.check_bool(output_recompute)
2361
+ if not self._has_config_recompute:
2362
+ self._has_config_recompute = True
2363
+ else:
2364
+ raise RuntimeError("The recompute interface can be configured only once."
2365
+ " When the parent cell is configured, the child cell should not be configured")
2366
+ self._set_recompute_scope(mode)
2367
+ if mode and not output_recompute:
2368
+ self.add_flags(output_no_recompute=True)
2369
+ for cell in self.cells():
2370
+ cell._recompute(mode, True)
2371
+
2372
+ @args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool)
2373
+ def recompute(self, **kwargs):
2374
+ """
2375
+         Set the cell recomputed. All the primitives in the cell, except for its outputs, will be set recomputed.
2376
+         If a primitive that is set recomputed feeds into some backward nodes that compute the gradient, then rather
2377
+         than storing its intermediate forward activation, the primitive is recomputed in the backward pass.
2378
+
2379
+ Note:
2380
+
2381
+ - If the computation involves something like randomization or global variable, the equivalence
2382
+ is not guaranteed currently.
2383
+ - If the recompute api of a primitive in this cell is also called, the recompute mode of this
2384
+ primitive is subject to the recompute api of the primitive.
2385
+ - The interface can be configured only once.
2386
+ Therefore, when the parent cell is configured, the child cell should not be configured.
2387
+             - The outputs of the cell are excluded from recomputation by default, which is based on our configuration
2388
+               experience to reduce the memory footprint. If a cell has only one primitive and that primitive should
2389
+               be recomputed, use the recompute api of the primitive.
2390
+             - When memory remains sufficient after applying recomputation, configure 'mp_comm_recompute=False'
2391
+               to improve performance if necessary.
2392
+             - When memory is still insufficient after applying recomputation, configure
2393
+               'parallel_optimizer_comm_recompute=True' to save more memory if necessary.
2394
+               Cells in the same fusion group should have the same parallel_optimizer_comm_recompute configuration.
2395
+
2396
+ Args:
2397
+ mp_comm_recompute (bool): Specifies whether the model parallel communication operators
2398
+ in the cell are recomputed in auto parallel or semi auto parallel mode. Default: ``True`` .
2399
+ parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
2400
+ introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
2401
+ Default: ``False`` .
2402
+ """
2403
+ if context.get_context("mode") == context.PYNATIVE_MODE:
2404
+ self.recompute_cell = recompute_registry.get()(self.construct)
2405
+ return
2406
+ self._recompute()
2407
+ if 'mp_comm_recompute' in kwargs.keys():
2408
+ self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
2409
+ if 'parallel_optimizer_comm_recompute' in kwargs.keys():
2410
+ if (kwargs.get('parallel_optimizer_comm_recompute', False) and
2411
+ context.get_auto_parallel_context("pipeline_stages") > 1):
2412
+ logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
2413
+                                "do not support recomputation in pipeline parallel.")
2414
+ elif context.get_auto_parallel_context("pipeline_stages") == 1:
2415
+ self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False))
2416
+ if 'recompute_slice_activation' in kwargs:
2417
+ self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False))
2418
+
2419
+ for key, _ in kwargs.items():
2420
+ if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'):
2421
+ raise ValueError("For 'recompute', keyword '%s' is not recognized! "
2422
+                                  "The keyword arguments must be 'mp_comm_recompute', "
2423
+                                  "'parallel_optimizer_comm_recompute' or 'recompute_slice_activation'." % key)
2424
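A short sketch of applying `recompute` to a sub-cell in graph mode; as noted above, once a parent cell is configured its children must not be configured again:

import mindspore as ms
from mindspore import nn

ms.set_context(mode=ms.GRAPH_MODE)

class Block(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(64, 64)
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(self.dense(x))

block = Block()
# Recompute this block's activations during the backward pass instead of storing them;
# mp_comm_recompute=False keeps model-parallel communication out of recomputation.
block.recompute(mp_comm_recompute=False)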
+
2425
+ @deprecated("2.3", "infer_param_pipeline_stage")
2426
+ def infer_param_pipeline_stage(self):
2427
+ """
2428
+ Infer pipeline stages of all parameters in the cell.
2429
+
2430
+ Note:
2431
+ - The interface is deprecated from version 2.3 and will be removed in a future version.
2432
+
2433
+ Returns:
2434
+             The parameters belonging to the current stage in pipeline parallel.
2435
+
2436
+ Raises:
2437
+             RuntimeError: If there is a parameter that does not belong to any stage.
2438
+ """
2439
+ from mindspore.parallel._utils import _get_global_rank, _get_device_num
2440
+         logger.warning("This interface may be deleted in the future.")
2441
+ stage_num = context.get_auto_parallel_context("pipeline_stages")
2442
+ device_num = _get_device_num()
2443
+ rank_id = _get_global_rank()
2444
+ per_stage_devices = device_num // stage_num
2445
+ current_stage = rank_id // per_stage_devices
2446
+ params = []
2447
+ for param in self.trainable_params():
2448
+ if not param._pipeline_stage_list: # pylint: disable=W0212
2449
+ raise RuntimeError("For 'infer_param_pipeline_stage', the parameter {} does not belong to any stage, "
2450
+                                      "please check whether the cell where the parameter is located has been set "
2451
+ "'pipeline_stage'. Otherwise, the parameter should use 'add_pipeline_stage' "
2452
+ "to add its stage information".format(param.name))
2453
+ if current_stage in param._pipeline_stage_list:
2454
+ params.append(param)
2455
+ return params
2456
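The stage selection above is plain integer division; a small worked illustration with assumed numbers:

# Suppose 8 devices and pipeline_stages = 2:
device_num, stage_num, rank_id = 8, 2, 5
per_stage_devices = device_num // stage_num   # 4 devices per stage
current_stage = rank_id // per_stage_devices  # rank 5 belongs to stage 1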
+
2457
+ def place(self, role, rank_id):
2458
+ """
2459
+ Set the label for all operators in this cell.
2460
+         This label tells the MindSpore compiler on which process this cell should be launched.
2461
+         Each process's identifying label consists of the input `role` and `rank_id`.
2462
+         By assigning different labels to different cells, which will then be launched on different processes,
2463
+         users can launch a distributed training or prediction job.
2464
+
2465
+ Note:
2466
+ - This method is effective only after
2467
+ `mindspore.communication.init()` is called for dynamic cluster building.
2468
+
2469
+ Args:
2470
+ role (str): The role of the process on which this cell will be launched.
2471
+ Only 'MS_WORKER' is supported for now.
2472
+ rank_id (int): The rank id of the process on which this cell will be launched.
2473
+ The rank is unique in processes with the same role.
2474
+
2475
+ Examples:
2476
+ >>> from mindspore import context
2477
+ >>> import mindspore.nn as nn
2478
+ >>> context.set_context(mode=context.GRAPH_MODE)
2479
+ >>> fc = nn.Dense(2, 3)
2480
+ >>> fc.place('MS_WORKER', 0)
2481
+ """
2482
+ all_ops = self._get_prims_recursively()
2483
+ for op in all_ops:
2484
+ op.place(role, rank_id)
2485
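Building on the docstring example, a hedged sketch of placing two sub-cells on different worker processes; it assumes a dynamic-cluster launch where `mindspore.communication.init()` succeeds on every process:

import mindspore as ms
from mindspore import nn
from mindspore.communication import init

ms.set_context(mode=ms.GRAPH_MODE)
init()  # dynamic cluster building must be initialized first

backbone = nn.Dense(128, 64)
head = nn.Dense(64, 10)
backbone.place('MS_WORKER', 0)  # backbone operators are launched on worker 0
head.place('MS_WORKER', 1)      # head operators are launched on worker 1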
+
2486
+ def _mixed_precision_cast(self, inputs):
2487
+ mixed_type = self.get_mixed_precision_type()
2488
+ if mixed_type == MixedPrecisionType.NOTSET:
2489
+ return inputs
2490
+ if mixed_type == MixedPrecisionType.FP16:
2491
+ cast_type = mstype.float16
2492
+ elif mixed_type == MixedPrecisionType.BF16:
2493
+ cast_type = mstype.bfloat16
2494
+ else:
2495
+ cast_type = mstype.float32
2496
+ cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type)
2497
+ return cast_inputs
2498
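The cast above is driven by the cell's mixed-precision setting; a minimal sketch, assuming `Cell.to_float` is used as the public way to request FP16 casting of a cell's float inputs:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

net = nn.Dense(4, 4).to_float(ms.float16)  # float32 inputs are cast to float16 inside this cell
x = Tensor(np.ones([2, 4]).astype(np.float32))
out = net(x)
print(out.dtype)  # expected to be Float16 under this setting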
+
2499
+ def _get_attr_from_cell(self, network):
2500
+ if not isinstance(network, Cell):
2501
+ return
2502
+ if hasattr(network, "jit_config_dict"):
2503
+ self._jit_config_dict = network.jit_config_dict
2504
+ if hasattr(network, "_amp_level"):
2505
+ self._amp_level = getattr(network, "_amp_level")
2506
+
2507
+
2508
+ class GraphCell(Cell):
2509
+ """
2510
+ Base class for running the graph loaded from a MindIR.
2511
+
2512
+     This feature is still under development. Currently `GraphCell` does not support modifying the structure of the
2513
+     graph, and it can only use data whose shape and type are the same as the input used when exporting the MindIR.
2514
+
2515
+ Args:
2516
+ graph (FuncGraph): A compiled graph loaded from MindIR.
2517
+         params_init (dict): Parameters that need to be initialized in the graph.
2518
+ The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
2519
+             If a parameter with the given name exists in the graph, its value is updated.
2520
+             If the parameter does not exist, it is ignored. Default: ``None`` .
2521
+ obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation. "dynamic obfuscation" is
2522
+ used for model protection, which can refer to :func:`mindspore.obfuscate_model`. If the input `graph` is
2523
+ a func_graph loaded from a mindir file obfuscated with `obf_random_seed` , then `obf_random_seed` should be
2524
+             provided. `obf_random_seed` should be in (0, 9223372036854775807]. Default: ``None`` .
2525
+
2526
+ Raises:
2527
+ TypeError: If the `graph` is not a FuncGraph.
2528
+ TypeError: If the `params_init` is not a dict.
2529
+ TypeError: If the key of the `params_init` is not a str.
2530
+ TypeError: If the value of the `params_init` is neither a Tensor nor a Parameter.
2531
+
2532
+ Supported Platforms:
2533
+ ``Ascend`` ``GPU`` ``CPU``
2534
+
2535
+ Examples:
2536
+ >>> import numpy as np
2537
+ >>> import mindspore as ms
2538
+ >>> import mindspore.nn as nn
2539
+ >>> from mindspore import Tensor
2540
+ >>> from mindspore import context
2541
+ >>> context.set_context(mode=context.GRAPH_MODE)
2542
+ >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
2543
+ >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
2544
+ >>> ms.export(net, input, file_name="net", file_format="MINDIR")
2545
+ >>> graph = ms.load("net.mindir")
2546
+ >>> net = nn.GraphCell(graph)
2547
+ >>> output = net(input)
2548
+ >>> print(output)
2549
+ [[[[4. 6. 4.]
2550
+ [6. 9. 6.]
2551
+ [4. 6. 4.]]]]
2552
+ """
2553
+
2554
+ def __init__(self, graph, params_init=None, obf_random_seed=None):
2555
+ super(GraphCell, self).__init__(auto_prefix=True)
2556
+ if not isinstance(graph, FuncGraph):
2557
+ raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, "
2558
+ f"but got type {type(graph)}.")
2559
+ self.graph = graph
2560
+ self.obf_random_seed = obf_random_seed
2561
+ if obf_random_seed is not None:
2562
+ if not isinstance(obf_random_seed, int):
2563
+ raise TypeError("'obf_random_seed' must be int, but got {}.".format(type(obf_random_seed)))
2564
+ int_64_max = 9223372036854775807
2565
+ if obf_random_seed <= 0 or obf_random_seed > int_64_max:
2566
+ raise ValueError(
2567
+                     "'obf_random_seed' must be larger than 0, and less than or equal to the int64 maximum ({}),"
2568
+ "but got {}.".format(int_64_max, obf_random_seed))
2569
+ self._branch_control_input = _generate_branch_control_input(self.obf_random_seed)
2570
+ params_init = {} if params_init is None else params_init
2571
+ if not isinstance(params_init, dict):
2572
+ raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")
2573
+ for name, value in params_init.items():
2574
+ if not isinstance(name, str) or not isinstance(value, Tensor):
2575
+ raise TypeError("For 'GraphCell', the key of the 'params_init' must be str, "
2576
+ "and the value must be Tensor or Parameter, "
2577
+ f"but got the key type: {type(name)}, and the value type: {type(value)}")
2578
+
2579
+ params_dict = update_func_graph_hyper_params(self.graph, params_init)
2580
+ for name, param in params_dict.items():
2581
+ self._params[name] = param
2582
+ _cell_graph_executor.inc_graph_cell_count()
2583
+
2584
+ def construct(self, *inputs):
2585
+ return self.graph(*inputs)
2586
+
2587
+ def __call__(self, *args, **kwargs):
2588
+ self.phase = "graph_load_from_mindir"
2589
+ self._add_attr("graph_load_from_mindir", self.graph)
2590
+ if not self.obf_random_seed:
2591
+ return self.compile_and_run(*args, **kwargs)
2592
+ append_input = Tensor((numpy.ones((1,)) * self._branch_control_input).astype(numpy.int32))
2593
+ return self.compile_and_run(*args, append_input, **kwargs)
2594
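A hedged sketch of overriding a parameter while loading a MindIR with `params_init`; the parameter name "weight" is an assumption and should be replaced with an actual name from the exported graph:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.GRAPH_MODE)
net = nn.Dense(2, 2, weight_init="ones")
x = Tensor(np.ones([1, 2]).astype(np.float32))
ms.export(net, x, file_name="dense", file_format="MINDIR")

graph = ms.load("dense.mindir")
# Replace the exported weight with zeros, looked up by name ("weight" is assumed here).
new_weight = Tensor(np.zeros([2, 2]).astype(np.float32))
loaded = nn.GraphCell(graph, params_init={"weight": new_weight})
print(loaded(x))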
+
2595
+
2596
+ def _check_param_list_tuple(value):
2597
+ """
2598
+     Check whether every item of a list or tuple is a Parameter.
2599
+     :param value: list or tuple.
2600
+     :return: True if every item is a Parameter, otherwise False.
2601
+ """
2602
+ for item in value:
2603
+ if not isinstance(item, Parameter):
2604
+ return False
2605
+ return True
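For completeness, a small hedged illustration of what this helper accepts and rejects:

import numpy as np
from mindspore import Parameter, Tensor

w = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="w")
print(_check_param_list_tuple([w, w]))    # True: every item is a Parameter
print(_check_param_list_tuple([w, 1.0]))  # False: 1.0 is not a Parameter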