mindspore-2.3.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
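Since this version is new to the registry, every file below appears as an addition. A wheel is an ordinary zip archive, so an inventory like the one that follows can be reproduced locally with the standard-library zipfile module; this is a minimal sketch, not the registry's actual tooling, and the local path is hypothetical (the wheel could be fetched first with, e.g., `pip download mindspore==2.3.0`).

```python
import zipfile

# Hypothetical local path to the downloaded wheel.
WHEEL_PATH = "mindspore-2.3.0-cp310-cp310-win_amd64.whl"

with zipfile.ZipFile(WHEEL_PATH) as whl:
    for info in whl.infolist():
        # Print each archived file with its uncompressed size in bytes.
        print(f"{info.filename}\t{info.file_size}")
```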
Files changed (1400)
  1. mindspore/.commit_id +1 -0
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +51 -0
  18. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +106 -0
  22. mindspore/_checkparam.py +1378 -0
  23. mindspore/_extends/__init__.py +23 -0
  24. mindspore/_extends/builtin_operations.py +224 -0
  25. mindspore/_extends/graph_kernel/__init__.py +17 -0
  26. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  27. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  28. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  29. mindspore/_extends/graph_kernel/model/model.py +553 -0
  30. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  31. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  32. mindspore/_extends/graph_kernel/splitter.py +140 -0
  33. mindspore/_extends/graph_kernel/utils.py +28 -0
  34. mindspore/_extends/parallel_compile/__init__.py +19 -0
  35. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  38. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  39. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  40. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  41. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  42. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  43. mindspore/_extends/parse/__init__.py +49 -0
  44. mindspore/_extends/parse/compile_config.py +258 -0
  45. mindspore/_extends/parse/namespace.py +136 -0
  46. mindspore/_extends/parse/parser.py +1446 -0
  47. mindspore/_extends/parse/resources.py +213 -0
  48. mindspore/_extends/parse/standard_method.py +4437 -0
  49. mindspore/_extends/parse/trope.py +97 -0
  50. mindspore/_extends/pijit/__init__.py +23 -0
  51. mindspore/_extends/pijit/pijit_func_white_list.py +343 -0
  52. mindspore/_extends/remote/__init__.py +19 -0
  53. mindspore/_extends/remote/kernel_build_server.py +199 -0
  54. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  57. mindspore/_extends/utils.py +68 -0
  58. mindspore/_install_custom.py +43 -0
  59. mindspore/_profiler.py +30 -0
  60. mindspore/amp.py +419 -0
  61. mindspore/atlprov.dll +0 -0
  62. mindspore/avcodec-59.dll +0 -0
  63. mindspore/avdevice-59.dll +0 -0
  64. mindspore/avfilter-8.dll +0 -0
  65. mindspore/avformat-59.dll +0 -0
  66. mindspore/avutil-57.dll +0 -0
  67. mindspore/boost/__init__.py +42 -0
  68. mindspore/boost/adasum.py +319 -0
  69. mindspore/boost/base.py +535 -0
  70. mindspore/boost/boost.py +400 -0
  71. mindspore/boost/boost_cell_wrapper.py +790 -0
  72. mindspore/boost/dim_reduce.py +323 -0
  73. mindspore/boost/grad_accumulation.py +79 -0
  74. mindspore/boost/grad_freeze.py +382 -0
  75. mindspore/boost/group_loss_scale_manager.py +166 -0
  76. mindspore/boost/less_batch_normalization.py +174 -0
  77. mindspore/c1.dll +0 -0
  78. mindspore/c1xx.dll +0 -0
  79. mindspore/c2.dll +0 -0
  80. mindspore/cfgpersist.dll +0 -0
  81. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  82. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  83. mindspore/common/__init__.py +84 -0
  84. mindspore/common/_auto_dynamic.py +68 -0
  85. mindspore/common/_decorator.py +50 -0
  86. mindspore/common/_jit_fallback_utils.py +110 -0
  87. mindspore/common/_monad.py +25 -0
  88. mindspore/common/_register_for_adapter.py +74 -0
  89. mindspore/common/_register_for_recompute.py +48 -0
  90. mindspore/common/_register_for_tensor.py +45 -0
  91. mindspore/common/_stub_tensor.py +210 -0
  92. mindspore/common/_utils.py +122 -0
  93. mindspore/common/api.py +2049 -0
  94. mindspore/common/auto_dynamic_shape.py +507 -0
  95. mindspore/common/dtype.py +422 -0
  96. mindspore/common/dump.py +131 -0
  97. mindspore/common/file_system.py +48 -0
  98. mindspore/common/generator.py +260 -0
  99. mindspore/common/hook_handle.py +155 -0
  100. mindspore/common/initializer.py +880 -0
  101. mindspore/common/jit_config.py +98 -0
  102. mindspore/common/lazy_inline.py +240 -0
  103. mindspore/common/mindir_util.py +111 -0
  104. mindspore/common/mutable.py +234 -0
  105. mindspore/common/no_inline.py +54 -0
  106. mindspore/common/np_dtype.py +25 -0
  107. mindspore/common/parameter.py +1048 -0
  108. mindspore/common/recompute.py +262 -0
  109. mindspore/common/seed.py +260 -0
  110. mindspore/common/sparse_tensor.py +1171 -0
  111. mindspore/common/symbol.py +122 -0
  112. mindspore/common/tensor.py +4859 -0
  113. mindspore/communication/__init__.py +37 -0
  114. mindspore/communication/_comm_helper.py +466 -0
  115. mindspore/communication/_hccl_management.py +297 -0
  116. mindspore/communication/comm_func.py +1140 -0
  117. mindspore/communication/management.py +673 -0
  118. mindspore/config/op_info.config +533 -0
  119. mindspore/context.py +1976 -0
  120. mindspore/d3dcompiler_47.dll +0 -0
  121. mindspore/dataset/__init__.py +90 -0
  122. mindspore/dataset/audio/__init__.py +61 -0
  123. mindspore/dataset/audio/transforms.py +3690 -0
  124. mindspore/dataset/audio/utils.py +386 -0
  125. mindspore/dataset/audio/validators.py +1172 -0
  126. mindspore/dataset/callback/__init__.py +20 -0
  127. mindspore/dataset/callback/ds_callback.py +368 -0
  128. mindspore/dataset/callback/validators.py +32 -0
  129. mindspore/dataset/core/__init__.py +13 -0
  130. mindspore/dataset/core/config.py +1088 -0
  131. mindspore/dataset/core/datatypes.py +101 -0
  132. mindspore/dataset/core/py_util_helpers.py +65 -0
  133. mindspore/dataset/core/validator_helpers.py +774 -0
  134. mindspore/dataset/debug/__init__.py +21 -0
  135. mindspore/dataset/debug/debug_hook.py +97 -0
  136. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  137. mindspore/dataset/engine/__init__.py +124 -0
  138. mindspore/dataset/engine/cache_admin.py +47 -0
  139. mindspore/dataset/engine/cache_client.py +129 -0
  140. mindspore/dataset/engine/datasets.py +4554 -0
  141. mindspore/dataset/engine/datasets_audio.py +911 -0
  142. mindspore/dataset/engine/datasets_standard_format.py +493 -0
  143. mindspore/dataset/engine/datasets_text.py +2161 -0
  144. mindspore/dataset/engine/datasets_user_defined.py +1114 -0
  145. mindspore/dataset/engine/datasets_vision.py +4816 -0
  146. mindspore/dataset/engine/iterators.py +342 -0
  147. mindspore/dataset/engine/obs/__init__.py +23 -0
  148. mindspore/dataset/engine/obs/config_loader.py +68 -0
  149. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  150. mindspore/dataset/engine/obs/util.py +475 -0
  151. mindspore/dataset/engine/offload.py +596 -0
  152. mindspore/dataset/engine/queue.py +250 -0
  153. mindspore/dataset/engine/samplers.py +895 -0
  154. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  155. mindspore/dataset/engine/validators.py +2875 -0
  156. mindspore/dataset/text/__init__.py +54 -0
  157. mindspore/dataset/text/transforms.py +1703 -0
  158. mindspore/dataset/text/utils.py +715 -0
  159. mindspore/dataset/text/validators.py +642 -0
  160. mindspore/dataset/transforms/__init__.py +48 -0
  161. mindspore/dataset/transforms/c_transforms.py +638 -0
  162. mindspore/dataset/transforms/py_transforms.py +393 -0
  163. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  164. mindspore/dataset/transforms/transforms.py +1260 -0
  165. mindspore/dataset/transforms/validators.py +410 -0
  166. mindspore/dataset/utils/__init__.py +19 -0
  167. mindspore/dataset/utils/browse_dataset.py +190 -0
  168. mindspore/dataset/utils/line_reader.py +124 -0
  169. mindspore/dataset/vision/__init__.py +68 -0
  170. mindspore/dataset/vision/c_transforms.py +2641 -0
  171. mindspore/dataset/vision/py_transforms.py +2120 -0
  172. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  173. mindspore/dataset/vision/transforms.py +7295 -0
  174. mindspore/dataset/vision/utils.py +863 -0
  175. mindspore/dataset/vision/validators.py +1482 -0
  176. mindspore/default_config.py +2 -0
  177. mindspore/dnnl.dll +0 -0
  178. mindspore/dpcmi.dll +0 -0
  179. mindspore/experimental/__init__.py +20 -0
  180. mindspore/experimental/map_parameter.py +309 -0
  181. mindspore/experimental/optim/__init__.py +40 -0
  182. mindspore/experimental/optim/adadelta.py +161 -0
  183. mindspore/experimental/optim/adagrad.py +168 -0
  184. mindspore/experimental/optim/adam.py +193 -0
  185. mindspore/experimental/optim/adamax.py +170 -0
  186. mindspore/experimental/optim/adamw.py +205 -0
  187. mindspore/experimental/optim/asgd.py +153 -0
  188. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  189. mindspore/experimental/optim/nadam.py +157 -0
  190. mindspore/experimental/optim/optimizer.py +259 -0
  191. mindspore/experimental/optim/radam.py +194 -0
  192. mindspore/experimental/optim/rmsprop.py +154 -0
  193. mindspore/experimental/optim/rprop.py +164 -0
  194. mindspore/experimental/optim/sgd.py +156 -0
  195. mindspore/hal/__init__.py +40 -0
  196. mindspore/hal/_ascend.py +57 -0
  197. mindspore/hal/_base.py +57 -0
  198. mindspore/hal/_cpu.py +56 -0
  199. mindspore/hal/_gpu.py +57 -0
  200. mindspore/hal/device.py +356 -0
  201. mindspore/hal/event.py +179 -0
  202. mindspore/hal/memory.py +326 -0
  203. mindspore/hal/stream.py +339 -0
  204. mindspore/include/OWNERS +7 -0
  205. mindspore/include/api/allocator.h +97 -0
  206. mindspore/include/api/callback/callback.h +93 -0
  207. mindspore/include/api/callback/ckpt_saver.h +41 -0
  208. mindspore/include/api/callback/loss_monitor.h +33 -0
  209. mindspore/include/api/callback/lr_scheduler.h +51 -0
  210. mindspore/include/api/callback/time_monitor.h +34 -0
  211. mindspore/include/api/callback/train_accuracy.h +37 -0
  212. mindspore/include/api/cell.h +90 -0
  213. mindspore/include/api/cfg.h +82 -0
  214. mindspore/include/api/context.h +602 -0
  215. mindspore/include/api/data_type.h +47 -0
  216. mindspore/include/api/delegate.h +178 -0
  217. mindspore/include/api/delegate_api.h +75 -0
  218. mindspore/include/api/dual_abi_helper.h +208 -0
  219. mindspore/include/api/format.h +28 -0
  220. mindspore/include/api/graph.h +46 -0
  221. mindspore/include/api/kernel.h +58 -0
  222. mindspore/include/api/kernel_api.h +168 -0
  223. mindspore/include/api/metrics/accuracy.h +36 -0
  224. mindspore/include/api/metrics/metrics.h +41 -0
  225. mindspore/include/api/model.h +438 -0
  226. mindspore/include/api/model_group.h +79 -0
  227. mindspore/include/api/model_parallel_runner.h +168 -0
  228. mindspore/include/api/serialization.h +185 -0
  229. mindspore/include/api/status.h +192 -0
  230. mindspore/include/api/types.h +431 -0
  231. mindspore/include/api/visible.h +41 -0
  232. mindspore/include/c_api/context_c.h +179 -0
  233. mindspore/include/c_api/data_type_c.h +52 -0
  234. mindspore/include/c_api/format_c.h +46 -0
  235. mindspore/include/c_api/model_c.h +347 -0
  236. mindspore/include/c_api/ms/abstract.h +67 -0
  237. mindspore/include/c_api/ms/attribute.h +197 -0
  238. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  239. mindspore/include/c_api/ms/base/macros.h +32 -0
  240. mindspore/include/c_api/ms/base/status.h +33 -0
  241. mindspore/include/c_api/ms/base/types.h +283 -0
  242. mindspore/include/c_api/ms/context.h +102 -0
  243. mindspore/include/c_api/ms/graph.h +160 -0
  244. mindspore/include/c_api/ms/node.h +606 -0
  245. mindspore/include/c_api/ms/tensor.h +161 -0
  246. mindspore/include/c_api/ms/value.h +84 -0
  247. mindspore/include/c_api/status_c.h +79 -0
  248. mindspore/include/c_api/tensor_c.h +146 -0
  249. mindspore/include/c_api/types_c.h +67 -0
  250. mindspore/include/dataset/config.h +163 -0
  251. mindspore/include/dataset/constants.h +363 -0
  252. mindspore/include/dataset/execute.h +196 -0
  253. mindspore/include/dataset/text.h +1092 -0
  254. mindspore/include/dataset/transforms.h +638 -0
  255. mindspore/include/dataset/vision.h +2125 -0
  256. mindspore/include/dataset/vision_ascend.h +206 -0
  257. mindspore/include/dataset/vision_lite.h +625 -0
  258. mindspore/jpeg62.dll +0 -0
  259. mindspore/log.py +633 -0
  260. mindspore/mindrecord/__init__.py +43 -0
  261. mindspore/mindrecord/common/__init__.py +17 -0
  262. mindspore/mindrecord/common/constant.py +20 -0
  263. mindspore/mindrecord/common/enums.py +44 -0
  264. mindspore/mindrecord/common/exceptions.py +311 -0
  265. mindspore/mindrecord/config.py +809 -0
  266. mindspore/mindrecord/filereader.py +174 -0
  267. mindspore/mindrecord/filewriter.py +705 -0
  268. mindspore/mindrecord/mindpage.py +210 -0
  269. mindspore/mindrecord/shardheader.py +141 -0
  270. mindspore/mindrecord/shardindexgenerator.py +74 -0
  271. mindspore/mindrecord/shardreader.py +117 -0
  272. mindspore/mindrecord/shardsegment.py +128 -0
  273. mindspore/mindrecord/shardutils.py +185 -0
  274. mindspore/mindrecord/shardwriter.py +237 -0
  275. mindspore/mindrecord/tools/__init__.py +17 -0
  276. mindspore/mindrecord/tools/cifar10.py +140 -0
  277. mindspore/mindrecord/tools/cifar100.py +153 -0
  278. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  279. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  280. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  281. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  282. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  283. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  284. mindspore/mindspore_backend.dll +0 -0
  285. mindspore/mindspore_common.dll +0 -0
  286. mindspore/mindspore_core.dll +0 -0
  287. mindspore/mindspore_glog.dll +0 -0
  288. mindspore/mindspore_np_dtype.dll +0 -0
  289. mindspore/mindspore_shared_lib.dll +0 -0
  290. mindspore/mint/__init__.py +1137 -0
  291. mindspore/mint/linalg/__init__.py +22 -0
  292. mindspore/mint/nn/__init__.py +512 -0
  293. mindspore/mint/nn/functional.py +573 -0
  294. mindspore/mint/optim/__init__.py +24 -0
  295. mindspore/mint/optim/adamw.py +185 -0
  296. mindspore/msobj140.dll +0 -0
  297. mindspore/mspdb140.dll +0 -0
  298. mindspore/mspdbcore.dll +0 -0
  299. mindspore/mspdbst.dll +0 -0
  300. mindspore/mspft140.dll +0 -0
  301. mindspore/msvcdis140.dll +0 -0
  302. mindspore/msvcp140.dll +0 -0
  303. mindspore/msvcp140_1.dll +0 -0
  304. mindspore/msvcp140_2.dll +0 -0
  305. mindspore/msvcp140_atomic_wait.dll +0 -0
  306. mindspore/msvcp140_codecvt_ids.dll +0 -0
  307. mindspore/multiprocessing/__init__.py +72 -0
  308. mindspore/nn/__init__.py +48 -0
  309. mindspore/nn/cell.py +2605 -0
  310. mindspore/nn/dynamic_lr.py +482 -0
  311. mindspore/nn/extend/__init__.py +29 -0
  312. mindspore/nn/extend/basic.py +140 -0
  313. mindspore/nn/extend/embedding.py +143 -0
  314. mindspore/nn/extend/layer/__init__.py +27 -0
  315. mindspore/nn/extend/layer/normalization.py +109 -0
  316. mindspore/nn/extend/pooling.py +117 -0
  317. mindspore/nn/grad/__init__.py +21 -0
  318. mindspore/nn/grad/cell_grad.py +196 -0
  319. mindspore/nn/layer/__init__.py +63 -0
  320. mindspore/nn/layer/activation.py +1655 -0
  321. mindspore/nn/layer/basic.py +1519 -0
  322. mindspore/nn/layer/channel_shuffle.py +90 -0
  323. mindspore/nn/layer/combined.py +248 -0
  324. mindspore/nn/layer/container.py +734 -0
  325. mindspore/nn/layer/conv.py +1505 -0
  326. mindspore/nn/layer/dense.py +204 -0
  327. mindspore/nn/layer/embedding.py +751 -0
  328. mindspore/nn/layer/embedding_service.py +531 -0
  329. mindspore/nn/layer/embedding_service_layer.py +393 -0
  330. mindspore/nn/layer/image.py +661 -0
  331. mindspore/nn/layer/math.py +1069 -0
  332. mindspore/nn/layer/normalization.py +1177 -0
  333. mindspore/nn/layer/padding.py +894 -0
  334. mindspore/nn/layer/pooling.py +2148 -0
  335. mindspore/nn/layer/rnn_cells.py +388 -0
  336. mindspore/nn/layer/rnns.py +849 -0
  337. mindspore/nn/layer/thor_layer.py +963 -0
  338. mindspore/nn/layer/timedistributed.py +155 -0
  339. mindspore/nn/layer/transformer.py +823 -0
  340. mindspore/nn/learning_rate_schedule.py +512 -0
  341. mindspore/nn/loss/__init__.py +36 -0
  342. mindspore/nn/loss/loss.py +2846 -0
  343. mindspore/nn/metrics.py +53 -0
  344. mindspore/nn/optim/__init__.py +44 -0
  345. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  346. mindspore/nn/optim/ada_grad.py +217 -0
  347. mindspore/nn/optim/adadelta.py +206 -0
  348. mindspore/nn/optim/adafactor.py +448 -0
  349. mindspore/nn/optim/adam.py +1297 -0
  350. mindspore/nn/optim/adamax.py +220 -0
  351. mindspore/nn/optim/adasum.py +548 -0
  352. mindspore/nn/optim/asgd.py +216 -0
  353. mindspore/nn/optim/ftrl.py +401 -0
  354. mindspore/nn/optim/lamb.py +296 -0
  355. mindspore/nn/optim/lars.py +202 -0
  356. mindspore/nn/optim/lazyadam.py +533 -0
  357. mindspore/nn/optim/momentum.py +239 -0
  358. mindspore/nn/optim/optimizer.py +1034 -0
  359. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  360. mindspore/nn/optim/rmsprop.py +264 -0
  361. mindspore/nn/optim/rprop.py +251 -0
  362. mindspore/nn/optim/sgd.py +237 -0
  363. mindspore/nn/optim/thor.py +1310 -0
  364. mindspore/nn/probability/__init__.py +22 -0
  365. mindspore/nn/probability/bijector/__init__.py +35 -0
  366. mindspore/nn/probability/bijector/bijector.py +337 -0
  367. mindspore/nn/probability/bijector/exp.py +65 -0
  368. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  369. mindspore/nn/probability/bijector/invert.py +126 -0
  370. mindspore/nn/probability/bijector/power_transform.py +196 -0
  371. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  372. mindspore/nn/probability/bijector/softplus.py +189 -0
  373. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  374. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  375. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  376. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  377. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  378. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  379. mindspore/nn/probability/distribution/__init__.py +56 -0
  380. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  381. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  382. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  383. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  384. mindspore/nn/probability/distribution/beta.py +391 -0
  385. mindspore/nn/probability/distribution/categorical.py +435 -0
  386. mindspore/nn/probability/distribution/cauchy.py +383 -0
  387. mindspore/nn/probability/distribution/distribution.py +827 -0
  388. mindspore/nn/probability/distribution/exponential.py +350 -0
  389. mindspore/nn/probability/distribution/gamma.py +391 -0
  390. mindspore/nn/probability/distribution/geometric.py +335 -0
  391. mindspore/nn/probability/distribution/gumbel.py +257 -0
  392. mindspore/nn/probability/distribution/half_normal.py +133 -0
  393. mindspore/nn/probability/distribution/laplace.py +128 -0
  394. mindspore/nn/probability/distribution/log_normal.py +272 -0
  395. mindspore/nn/probability/distribution/logistic.py +379 -0
  396. mindspore/nn/probability/distribution/normal.py +336 -0
  397. mindspore/nn/probability/distribution/poisson.py +288 -0
  398. mindspore/nn/probability/distribution/student_t.py +149 -0
  399. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  400. mindspore/nn/probability/distribution/uniform.py +375 -0
  401. mindspore/nn/reinforcement/__init__.py +24 -0
  402. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  403. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  404. mindspore/nn/reinforcement/tensor_array.py +145 -0
  405. mindspore/nn/sparse/__init__.py +23 -0
  406. mindspore/nn/sparse/sparse.py +147 -0
  407. mindspore/nn/wrap/__init__.py +49 -0
  408. mindspore/nn/wrap/cell_wrapper.py +979 -0
  409. mindspore/nn/wrap/grad_reducer.py +608 -0
  410. mindspore/nn/wrap/loss_scale.py +680 -0
  411. mindspore/numpy/__init__.py +121 -0
  412. mindspore/numpy/array_creations.py +2734 -0
  413. mindspore/numpy/array_ops.py +2625 -0
  414. mindspore/numpy/dtypes.py +185 -0
  415. mindspore/numpy/fft.py +431 -0
  416. mindspore/numpy/logic_ops.py +935 -0
  417. mindspore/numpy/math_ops.py +5910 -0
  418. mindspore/numpy/utils.py +214 -0
  419. mindspore/numpy/utils_const.py +565 -0
  420. mindspore/opencv_core452.dll +0 -0
  421. mindspore/opencv_imgcodecs452.dll +0 -0
  422. mindspore/opencv_imgproc452.dll +0 -0
  423. mindspore/ops/__init__.py +54 -0
  424. mindspore/ops/_constants.py +30 -0
  425. mindspore/ops/_grad_experimental/__init__.py +31 -0
  426. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  427. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  428. mindspore/ops/_grad_experimental/grad_comm_ops.py +670 -0
  429. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  430. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  431. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  432. mindspore/ops/_grad_experimental/grad_math_ops.py +824 -0
  433. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  434. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  435. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  436. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  437. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  438. mindspore/ops/_op_impl/__init__.py +23 -0
  439. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  440. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  441. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  442. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  443. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  444. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  445. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  446. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  447. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  448. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  449. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  450. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  451. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  452. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  453. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  454. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  455. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  456. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  457. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  458. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  459. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  460. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  461. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  462. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  463. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  464. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  465. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  466. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  467. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  468. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  469. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  470. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  471. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  472. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  473. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  474. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  475. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  476. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  477. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  478. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  479. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  480. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  481. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  482. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  483. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  484. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  485. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  486. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  487. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  488. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  489. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  490. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  491. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  492. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  493. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  494. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  495. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  497. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  498. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  499. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  500. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  501. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  502. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  503. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  504. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  505. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  506. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  507. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  508. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  509. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  510. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  511. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  512. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  513. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  514. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  515. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  516. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  519. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  520. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  521. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  522. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  523. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  524. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  525. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  526. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  527. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  528. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  529. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  530. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  531. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  532. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  533. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  534. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  535. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  536. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  537. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  538. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  539. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  540. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  541. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  542. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  543. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  544. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  545. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  546. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  547. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  548. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  549. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  550. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  551. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  552. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  553. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  554. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  555. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  556. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  557. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  558. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  559. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  560. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  561. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  562. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  563. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  564. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  565. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  566. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  567. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  568. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  569. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  570. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  571. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  572. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  573. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  574. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  575. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  576. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  577. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  578. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  579. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  580. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  581. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  582. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  583. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  584. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  585. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  586. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  587. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  588. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  589. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  590. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  591. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  592. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  593. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  594. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  595. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  596. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  597. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  598. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  599. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  600. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  601. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  602. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  603. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  604. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  605. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  606. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  607. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  608. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  609. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  610. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  611. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  612. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  613. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  614. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  615. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  616. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  617. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  618. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  619. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  620. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  621. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  622. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  623. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  624. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  625. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  626. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  627. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  628. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  629. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  630. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  631. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  632. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  633. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  634. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  635. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  636. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  637. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  638. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  640. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  641. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  643. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  644. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  645. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  646. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  647. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  648. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  649. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  650. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  651. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  652. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  653. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  654. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  655. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  656. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  658. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  659. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  660. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  661. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  662. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  663. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  664. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  665. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  666. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  667. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  668. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  669. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  670. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  671. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  672. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  673. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  674. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  675. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  676. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  677. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  678. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  679. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  680. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  681. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  682. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  683. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  684. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  685. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  686. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  687. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  688. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  689. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  690. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  691. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  692. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  693. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  694. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  695. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  696. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  697. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  698. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  699. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  700. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  701. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  702. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  703. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  704. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  705. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  706. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  707. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  708. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  709. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  710. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  711. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  712. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  713. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  714. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  715. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  716. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  717. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  718. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  719. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  720. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  721. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  722. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  723. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  724. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  725. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  727. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  729. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  730. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  732. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  733. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  734. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  735. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  736. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  737. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  738. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  739. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  740. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  741. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  742. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  744. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  745. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  746. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  747. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  749. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  750. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  751. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  752. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  753. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  754. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  755. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  756. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  757. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  758. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  759. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  760. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  761. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  762. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  763. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  764. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  765. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  766. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  767. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  768. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  769. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  771. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  772. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  773. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  774. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  775. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  776. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  777. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  778. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  779. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  780. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  781. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  782. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  783. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  784. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  785. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  786. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  787. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  788. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  789. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  790. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  791. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  793. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  794. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  795. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  796. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  797. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  798. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  799. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  800. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  801. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  802. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  803. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  804. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  805. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  806. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  807. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  808. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  809. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  810. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  811. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  812. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  813. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  814. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  815. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  816. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  817. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  818. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  819. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  820. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  821. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  822. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  823. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  824. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  825. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  826. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  827. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  841. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  842. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  843. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  844. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  845. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  846. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  847. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  848. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  849. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  850. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  851. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  852. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  853. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  854. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  855. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  856. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  857. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  858. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  859. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  860. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  861. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  862. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  863. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  865. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  866. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  867. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  868. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  869. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  870. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  871. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  873. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  874. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  875. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  876. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  877. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  878. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  879. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  880. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  881. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  882. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  883. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  884. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  885. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  886. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  887. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  888. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  889. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  890. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  891. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  892. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  893. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  894. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  895. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  896. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  897. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  898. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  899. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  900. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  901. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  902. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  903. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  904. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  905. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  906. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  907. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  908. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  909. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  910. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  911. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  912. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  913. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  914. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  915. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  916. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  917. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  918. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  919. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  920. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  921. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  922. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  923. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  924. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  925. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  926. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  927. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  929. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  930. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  931. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  932. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  933. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  934. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  935. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  936. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  937. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  938. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  939. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  940. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  941. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  942. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  943. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  944. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  945. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  946. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  947. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  948. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  949. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  950. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  951. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  952. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  953. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  954. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  955. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  956. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  957. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  958. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  959. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  960. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  961. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  962. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  963. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  964. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  965. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  966. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  967. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  968. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  969. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  970. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  971. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  972. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  973. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  974. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  975. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  976. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  977. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  978. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  979. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  980. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  981. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  982. mindspore/ops/_op_impl/cpu/div.py +32 -0
  983. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  984. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  985. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  986. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  987. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  988. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  989. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  990. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  991. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  992. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  993. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  994. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  995. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  996. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  997. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  998. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  999. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  1000. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  1001. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  1002. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  1003. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  1004. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  1005. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  1006. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  1007. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  1009. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  1010. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  1011. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  1012. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  1013. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  1014. mindspore/ops/_op_impl/cpu/range.py +34 -0
  1015. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  1016. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  1017. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  1018. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  1019. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  1020. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  1021. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  1022. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1023. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1024. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1025. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1026. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1027. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1028. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1029. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1030. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1031. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1032. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1033. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1034. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1035. mindspore/ops/_primitive_cache.py +90 -0
  1036. mindspore/ops/_register_for_op.py +73 -0
  1037. mindspore/ops/_utils/__init__.py +20 -0
  1038. mindspore/ops/_utils/utils.py +147 -0
  1039. mindspore/ops/_vmap/__init__.py +25 -0
  1040. mindspore/ops/_vmap/vmap_array_ops.py +2151 -0
  1041. mindspore/ops/_vmap/vmap_base.py +533 -0
  1042. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1043. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1044. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1045. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1046. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1047. mindspore/ops/_vmap/vmap_math_ops.py +977 -0
  1048. mindspore/ops/_vmap/vmap_nn_ops.py +2209 -0
  1049. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1050. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1051. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1052. mindspore/ops/auto_generate/__init__.py +31 -0
  1053. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +231 -0
  1054. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +250 -0
  1055. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1056. mindspore/ops/auto_generate/gen_extend_func.py +980 -0
  1057. mindspore/ops/auto_generate/gen_ops_def.py +6443 -0
  1058. mindspore/ops/auto_generate/gen_ops_prim.py +13167 -0
  1059. mindspore/ops/auto_generate/pyboost_inner_prim.py +429 -0
  1060. mindspore/ops/composite/__init__.py +71 -0
  1061. mindspore/ops/composite/base.py +1281 -0
  1062. mindspore/ops/composite/env_ops.py +41 -0
  1063. mindspore/ops/composite/math_ops.py +125 -0
  1064. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1065. mindspore/ops/composite/multitype_ops/_compile_utils.py +1458 -0
  1066. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1067. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1068. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1069. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1070. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1071. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1072. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1073. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1074. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1075. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1076. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1077. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1078. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1079. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1080. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1081. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1082. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1083. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1084. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1085. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1086. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1087. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1088. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1089. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1090. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1091. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1092. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1093. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1094. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1095. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1096. mindspore/ops/deprecated.py +315 -0
  1097. mindspore/ops/extend/__init__.py +53 -0
  1098. mindspore/ops/extend/array_func.py +218 -0
  1099. mindspore/ops/extend/math_func.py +76 -0
  1100. mindspore/ops/extend/nn_func.py +308 -0
  1101. mindspore/ops/function/__init__.py +760 -0
  1102. mindspore/ops/function/array_func.py +6889 -0
  1103. mindspore/ops/function/clip_func.py +384 -0
  1104. mindspore/ops/function/debug_func.py +69 -0
  1105. mindspore/ops/function/fft_func.py +31 -0
  1106. mindspore/ops/function/grad/__init__.py +34 -0
  1107. mindspore/ops/function/grad/grad_func.py +1424 -0
  1108. mindspore/ops/function/image_func.py +292 -0
  1109. mindspore/ops/function/linalg_func.py +416 -0
  1110. mindspore/ops/function/math_func.py +11877 -0
  1111. mindspore/ops/function/nn_func.py +8175 -0
  1112. mindspore/ops/function/other_func.py +114 -0
  1113. mindspore/ops/function/parameter_func.py +134 -0
  1114. mindspore/ops/function/random_func.py +1539 -0
  1115. mindspore/ops/function/reshard_func.py +102 -0
  1116. mindspore/ops/function/sparse_func.py +884 -0
  1117. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1118. mindspore/ops/function/spectral_func.py +150 -0
  1119. mindspore/ops/function/vmap_func.py +116 -0
  1120. mindspore/ops/functional.py +454 -0
  1121. mindspore/ops/op_info_register.py +1572 -0
  1122. mindspore/ops/operations/__init__.py +717 -0
  1123. mindspore/ops/operations/_csr_ops.py +403 -0
  1124. mindspore/ops/operations/_custom_grad.py +181 -0
  1125. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1126. mindspore/ops/operations/_grad_ops.py +3052 -0
  1127. mindspore/ops/operations/_infer_ops.py +19 -0
  1128. mindspore/ops/operations/_inner_ops.py +2567 -0
  1129. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1130. mindspore/ops/operations/_ms_kernel.py +601 -0
  1131. mindspore/ops/operations/_ocr_ops.py +379 -0
  1132. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1133. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1134. mindspore/ops/operations/_quant_ops.py +1844 -0
  1135. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1136. mindspore/ops/operations/_scalar_ops.py +106 -0
  1137. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1138. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1139. mindspore/ops/operations/_tensor_array.py +359 -0
  1140. mindspore/ops/operations/_thor_ops.py +807 -0
  1141. mindspore/ops/operations/array_ops.py +6258 -0
  1142. mindspore/ops/operations/comm_ops.py +1996 -0
  1143. mindspore/ops/operations/control_ops.py +127 -0
  1144. mindspore/ops/operations/custom_ops.py +1065 -0
  1145. mindspore/ops/operations/debug_ops.py +646 -0
  1146. mindspore/ops/operations/image_ops.py +1041 -0
  1147. mindspore/ops/operations/inner_ops.py +697 -0
  1148. mindspore/ops/operations/linalg_ops.py +95 -0
  1149. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1150. mindspore/ops/operations/manually_defined/_inner.py +61 -0
  1151. mindspore/ops/operations/manually_defined/ops_def.py +2016 -0
  1152. mindspore/ops/operations/math_ops.py +5306 -0
  1153. mindspore/ops/operations/nn_ops.py +9669 -0
  1154. mindspore/ops/operations/other_ops.py +871 -0
  1155. mindspore/ops/operations/random_ops.py +1243 -0
  1156. mindspore/ops/operations/reshard_ops.py +53 -0
  1157. mindspore/ops/operations/rl_ops.py +288 -0
  1158. mindspore/ops/operations/sparse_ops.py +2753 -0
  1159. mindspore/ops/operations/spectral_ops.py +111 -0
  1160. mindspore/ops/primitive.py +1034 -0
  1161. mindspore/ops/signature.py +54 -0
  1162. mindspore/ops/silent_check.py +162 -0
  1163. mindspore/ops/vm_impl_registry.py +91 -0
  1164. mindspore/ops_generate/__init__.py +27 -0
  1165. mindspore/ops_generate/arg_dtype_cast.py +250 -0
  1166. mindspore/ops_generate/arg_handler.py +197 -0
  1167. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1168. mindspore/ops_generate/gen_ops.py +1084 -0
  1169. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1170. mindspore/ops_generate/gen_pyboost_func.py +968 -0
  1171. mindspore/ops_generate/gen_utils.py +209 -0
  1172. mindspore/ops_generate/op_proto.py +138 -0
  1173. mindspore/ops_generate/pyboost_utils.py +354 -0
  1174. mindspore/ops_generate/template.py +239 -0
  1175. mindspore/parallel/__init__.py +28 -0
  1176. mindspore/parallel/_auto_parallel_context.py +1466 -0
  1177. mindspore/parallel/_cell_wrapper.py +91 -0
  1178. mindspore/parallel/_cost_model_context.py +700 -0
  1179. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1180. mindspore/parallel/_offload_context.py +275 -0
  1181. mindspore/parallel/_parallel_serialization.py +533 -0
  1182. mindspore/parallel/_ps_context.py +242 -0
  1183. mindspore/parallel/_recovery_context.py +110 -0
  1184. mindspore/parallel/_tensor.py +660 -0
  1185. mindspore/parallel/_transformer/__init__.py +35 -0
  1186. mindspore/parallel/_transformer/layers.py +765 -0
  1187. mindspore/parallel/_transformer/loss.py +251 -0
  1188. mindspore/parallel/_transformer/moe.py +693 -0
  1189. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1190. mindspore/parallel/_transformer/transformer.py +3119 -0
  1191. mindspore/parallel/_utils.py +600 -0
  1192. mindspore/parallel/algo_parameter_config.py +400 -0
  1193. mindspore/parallel/checkpoint_transform.py +643 -0
  1194. mindspore/parallel/cluster/__init__.py +15 -0
  1195. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1196. mindspore/parallel/cluster/process_entity/_api.py +344 -0
  1197. mindspore/parallel/cluster/process_entity/_utils.py +126 -0
  1198. mindspore/parallel/cluster/run.py +136 -0
  1199. mindspore/parallel/mpi/__init__.py +14 -0
  1200. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1201. mindspore/parallel/parameter_broadcast.py +152 -0
  1202. mindspore/parallel/shard.py +350 -0
  1203. mindspore/perf_msvcbuildinsights.dll +0 -0
  1204. mindspore/pgodb140.dll +0 -0
  1205. mindspore/pgort140.dll +0 -0
  1206. mindspore/profiler/__init__.py +27 -0
  1207. mindspore/profiler/common/__init__.py +14 -0
  1208. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1209. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1210. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1211. mindspore/profiler/common/process_pool.py +41 -0
  1212. mindspore/profiler/common/singleton.py +28 -0
  1213. mindspore/profiler/common/struct_type.py +118 -0
  1214. mindspore/profiler/common/util.py +444 -0
  1215. mindspore/profiler/common/validator/__init__.py +14 -0
  1216. mindspore/profiler/common/validator/validate_path.py +84 -0
  1217. mindspore/profiler/envprofiling.py +256 -0
  1218. mindspore/profiler/parser/__init__.py +14 -0
  1219. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1220. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1221. mindspore/profiler/parser/ascend_analysis/constant.py +53 -0
  1222. mindspore/profiler/parser/ascend_analysis/file_manager.py +159 -0
  1223. mindspore/profiler/parser/ascend_analysis/function_event.py +161 -0
  1224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +131 -0
  1225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +85 -0
  1226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +57 -0
  1227. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
  1228. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1229. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +68 -0
  1230. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1231. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1232. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1233. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1234. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1235. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1236. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1237. mindspore/profiler/parser/ascend_msprof_exporter.py +281 -0
  1238. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1239. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1240. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1241. mindspore/profiler/parser/ascend_timeline_generator.py +543 -0
  1242. mindspore/profiler/parser/base_timeline_generator.py +489 -0
  1243. mindspore/profiler/parser/container.py +229 -0
  1244. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  1245. mindspore/profiler/parser/flops_parser.py +531 -0
  1246. mindspore/profiler/parser/framework_enum.py +111 -0
  1247. mindspore/profiler/parser/framework_parser.py +854 -0
  1248. mindspore/profiler/parser/framework_struct.py +61 -0
  1249. mindspore/profiler/parser/hccl_parser.py +573 -0
  1250. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1251. mindspore/profiler/parser/integrator.py +526 -0
  1252. mindspore/profiler/parser/memory_usage_parser.py +431 -0
  1253. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1254. mindspore/profiler/parser/minddata_parser.py +186 -0
  1255. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1256. mindspore/profiler/parser/msadvisor_analyzer.py +82 -0
  1257. mindspore/profiler/parser/msadvisor_parser.py +240 -0
  1258. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1259. mindspore/profiler/parser/optime_parser.py +250 -0
  1260. mindspore/profiler/parser/profiler_info.py +141 -0
  1261. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1262. mindspore/profiler/profiling.py +2054 -0
  1263. mindspore/rewrite/__init__.py +29 -0
  1264. mindspore/rewrite/api/__init__.py +17 -0
  1265. mindspore/rewrite/api/node.py +519 -0
  1266. mindspore/rewrite/api/node_type.py +53 -0
  1267. mindspore/rewrite/api/pattern_engine.py +490 -0
  1268. mindspore/rewrite/api/scoped_value.py +181 -0
  1269. mindspore/rewrite/api/symbol_tree.py +497 -0
  1270. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1271. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1272. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1273. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1274. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1275. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1276. mindspore/rewrite/common/__init__.py +19 -0
  1277. mindspore/rewrite/common/config.py +24 -0
  1278. mindspore/rewrite/common/error_log.py +39 -0
  1279. mindspore/rewrite/common/event.py +28 -0
  1280. mindspore/rewrite/common/namer.py +271 -0
  1281. mindspore/rewrite/common/namespace.py +118 -0
  1282. mindspore/rewrite/common/observable.py +44 -0
  1283. mindspore/rewrite/common/observer.py +54 -0
  1284. mindspore/rewrite/node/__init__.py +22 -0
  1285. mindspore/rewrite/node/call_function.py +95 -0
  1286. mindspore/rewrite/node/cell_container.py +139 -0
  1287. mindspore/rewrite/node/control_flow.py +113 -0
  1288. mindspore/rewrite/node/node.py +1428 -0
  1289. mindspore/rewrite/node/node_manager.py +283 -0
  1290. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1291. mindspore/rewrite/parsers/__init__.py +29 -0
  1292. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1293. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1294. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1295. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1296. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1297. mindspore/rewrite/parsers/container_parser.py +88 -0
  1298. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1299. mindspore/rewrite/parsers/for_parser.py +61 -0
  1300. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1301. mindspore/rewrite/parsers/if_parser.py +85 -0
  1302. mindspore/rewrite/parsers/module_parser.py +117 -0
  1303. mindspore/rewrite/parsers/parser.py +43 -0
  1304. mindspore/rewrite/parsers/parser_register.py +86 -0
  1305. mindspore/rewrite/parsers/return_parser.py +37 -0
  1306. mindspore/rewrite/parsers/while_parser.py +59 -0
  1307. mindspore/rewrite/sparsify/__init__.py +0 -0
  1308. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1309. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1310. mindspore/rewrite/sparsify/utils.py +179 -0
  1311. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1312. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1313. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1314. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1315. mindspore/run_check/__init__.py +20 -0
  1316. mindspore/run_check/_check_version.py +574 -0
  1317. mindspore/run_check/run_check.py +66 -0
  1318. mindspore/safeguard/__init__.py +18 -0
  1319. mindspore/safeguard/rewrite_obfuscation.py +531 -0
  1320. mindspore/swresample-4.dll +0 -0
  1321. mindspore/swscale-6.dll +0 -0
  1322. mindspore/tbbmalloc.dll +0 -0
  1323. mindspore/tinyxml2.dll +0 -0
  1324. mindspore/train/__init__.py +47 -0
  1325. mindspore/train/_utils.py +439 -0
  1326. mindspore/train/amp.py +817 -0
  1327. mindspore/train/anf_ir_pb2.py +1517 -0
  1328. mindspore/train/callback/__init__.py +44 -0
  1329. mindspore/train/callback/_backup_and_restore.py +117 -0
  1330. mindspore/train/callback/_callback.py +613 -0
  1331. mindspore/train/callback/_checkpoint.py +751 -0
  1332. mindspore/train/callback/_cluster_monitor.py +201 -0
  1333. mindspore/train/callback/_dataset_graph.py +150 -0
  1334. mindspore/train/callback/_early_stop.py +239 -0
  1335. mindspore/train/callback/_flops_collector.py +238 -0
  1336. mindspore/train/callback/_history.py +92 -0
  1337. mindspore/train/callback/_lambda_callback.py +80 -0
  1338. mindspore/train/callback/_landscape.py +1049 -0
  1339. mindspore/train/callback/_loss_monitor.py +107 -0
  1340. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1341. mindspore/train/callback/_mindio_ttp.py +443 -0
  1342. mindspore/train/callback/_on_request_exit.py +195 -0
  1343. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1344. mindspore/train/callback/_summary_collector.py +1184 -0
  1345. mindspore/train/callback/_time_monitor.py +141 -0
  1346. mindspore/train/checkpoint_pb2.py +233 -0
  1347. mindspore/train/data_sink.py +219 -0
  1348. mindspore/train/dataset_helper.py +688 -0
  1349. mindspore/train/lineage_pb2.py +1260 -0
  1350. mindspore/train/loss_scale_manager.py +213 -0
  1351. mindspore/train/memory_profiling_pb2.py +298 -0
  1352. mindspore/train/metrics/__init__.py +175 -0
  1353. mindspore/train/metrics/accuracy.py +133 -0
  1354. mindspore/train/metrics/auc.py +129 -0
  1355. mindspore/train/metrics/bleu_score.py +170 -0
  1356. mindspore/train/metrics/confusion_matrix.py +700 -0
  1357. mindspore/train/metrics/cosine_similarity.py +109 -0
  1358. mindspore/train/metrics/dice.py +116 -0
  1359. mindspore/train/metrics/error.py +175 -0
  1360. mindspore/train/metrics/fbeta.py +167 -0
  1361. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1362. mindspore/train/metrics/loss.py +97 -0
  1363. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1364. mindspore/train/metrics/metric.py +373 -0
  1365. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1366. mindspore/train/metrics/perplexity.py +133 -0
  1367. mindspore/train/metrics/precision.py +160 -0
  1368. mindspore/train/metrics/recall.py +159 -0
  1369. mindspore/train/metrics/roc.py +223 -0
  1370. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1371. mindspore/train/metrics/topk.py +167 -0
  1372. mindspore/train/mind_ir_pb2.py +1903 -0
  1373. mindspore/train/model.py +2176 -0
  1374. mindspore/train/node_strategy_pb2.py +653 -0
  1375. mindspore/train/print_pb2.py +184 -0
  1376. mindspore/train/profiling_parallel_pb2.py +151 -0
  1377. mindspore/train/serialization.py +3101 -0
  1378. mindspore/train/summary/__init__.py +23 -0
  1379. mindspore/train/summary/_lineage_adapter.py +41 -0
  1380. mindspore/train/summary/_summary_adapter.py +496 -0
  1381. mindspore/train/summary/_writer_pool.py +207 -0
  1382. mindspore/train/summary/enums.py +56 -0
  1383. mindspore/train/summary/summary_record.py +581 -0
  1384. mindspore/train/summary/writer.py +167 -0
  1385. mindspore/train/summary_pb2.py +1165 -0
  1386. mindspore/train/train_thor/__init__.py +20 -0
  1387. mindspore/train/train_thor/convert_utils.py +268 -0
  1388. mindspore/train/train_thor/dataset_helper.py +192 -0
  1389. mindspore/train/train_thor/model_thor.py +257 -0
  1390. mindspore/turbojpeg.dll +0 -0
  1391. mindspore/vcmeta.dll +0 -0
  1392. mindspore/vcomp140.dll +0 -0
  1393. mindspore/vcruntime140.dll +0 -0
  1394. mindspore/vcruntime140_1.dll +0 -0
  1395. mindspore/version.py +1 -0
  1396. mindspore-2.3.0.dist-info/METADATA +351 -0
  1397. mindspore-2.3.0.dist-info/RECORD +1400 -0
  1398. mindspore-2.3.0.dist-info/WHEEL +5 -0
  1399. mindspore-2.3.0.dist-info/entry_points.txt +4 -0
  1400. mindspore-2.3.0.dist-info/top_level.txt +1 -0
mindspore/context.py ADDED
@@ -0,0 +1,1976 @@
+ # Copyright 2020-2024 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """
+ The context of MindSpore, used to configure the current execution environment,
+ including the execution mode, the execution backend and other feature switches.
+ """
+ from __future__ import absolute_import
+
+ import json
+ import os
+ import time
+ import threading
+ from collections import namedtuple
+ from types import FunctionType
+
+ from mindspore import log as logger
+ from mindspore._c_expression import MSContext, ms_ctx_param
+ from mindspore import _checkparam as Validator
+ from mindspore._checkparam import args_type_check
+ from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
+     _reset_auto_parallel_context
+ from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context, \
+     _need_reset_device_target_for_ps
+ from mindspore.parallel._offload_context import _set_offload_context, _get_offload_context
+ from mindspore.hal.device import is_initialized
+
+ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'STRICT', 'COMPATIBLE', 'LAX', 'set_context', 'get_context',
+            'set_auto_parallel_context', 'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode',
+            'set_ps_context', 'get_ps_context', 'reset_ps_context', 'set_offload_context', 'get_offload_context']
+
+ GRAPH_MODE = 0
+ PYNATIVE_MODE = 1
+ _DEVICE_APP_MEMORY_SIZE = 31  # The max memory size of graph plus variable.
+ _RE_PATTERN = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
+ K_CONTEXT = None
+
+ # Enumerate for the property 'jit_syntax_level'.
+ STRICT = 0
+ COMPATIBLE = 1
+ LAX = 2
+
+ # Enumerate for the property 'debug_level'.
+ RELEASE = 0
+ DEBUG = 1
+
+
+ def _make_directory(path):
+     """Make directory."""
+     if path is None or not isinstance(path, str) or path.strip() == "":
+         raise ValueError(f"For 'context.set_context', the 'save_graphs_path' or the 'print_file_path' is of an "
+                          f"invalid type, it should be a non-empty string, but got '{path}'.")
+
+     path = os.path.realpath(path)
+     logger.debug("The absolute path is %r", path)
+
+     if not os.path.exists(path):
+         logger.debug("The directory(%s) doesn't exist, will create it", path)
+         try:
+             os.makedirs(path)
+         except FileExistsError:
+             logger.debug("The directory(%s) already exists.", path)
+         except PermissionError as e:
+             logger.critical(f"No write permission on the directory '{path}', error = {e}")
+             raise ValueError(e.__str__() + f"\nNo write permission on the directory '{path}'.")
+     return path
+
+
+ def _get_print_file_name(file_name):
+     """Add a timestamp suffix to the file name: file_name + "." + time(seconds)."""
+     time_second = str(int(time.time()))
+     file_name = file_name + "." + time_second
+     if os.path.exists(file_name):
+         raise ValueError("For 'context.set_context', the argument 'print_file_path' {} already exists, "
+                          "please check it".format(file_name))
+     return file_name
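+
+ # A minimal sketch of the helper above; the suffix depends on the wall clock,
+ # so the value shown is illustrative only:
+ #     _get_print_file_name("print.data")  # -> e.g. "print.data.1718000000"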
+
+
+ class _ThreadLocalInfo(threading.local):
+     """
+     Thread-local info used to store thread-local attributes.
+     """
+
+     def __init__(self):
+         super(_ThreadLocalInfo, self).__init__()
+         self._reserve_class_name_in_scope = True
+         self.debug_runtime = False
+
+     @property
+     def reserve_class_name_in_scope(self):
+         """Get whether to save the network class name in the scope."""
+         return self._reserve_class_name_in_scope
+
+     @reserve_class_name_in_scope.setter
+     def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
+         """Set whether to save the network class name in the scope."""
+         self._reserve_class_name_in_scope = reserve_class_name_in_scope
+
+
+ _ContextRecord = namedtuple(
+     "_ContextRecord", ["is_pynative_mode", "switch_context_fn"])
+
+
+ class _ContextSwitchInfo(threading.local):
+     """
+     Record of context switch information.
+
+     Args:
+         is_pynative (bool): Whether to adopt the PyNative mode.
+     """
+
+     def __init__(self, is_pynative):
+         super(_ContextSwitchInfo, self).__init__()
+         self.context_stack = []
+         if is_pynative:
+             self.push(True, None)
+
+     def push(self, is_pynative, switch_context_fn):
+         """
+         Push a context switch record onto the stack.
+
+         Args:
+             is_pynative (bool): Whether the context switches to PyNative mode.
+             switch_context_fn (Function): A callable that executes the context switch.
+         """
+         if isinstance(switch_context_fn, FunctionType):
+             switch_context_fn()
+         self.context_stack.append(
+             _ContextRecord(is_pynative, switch_context_fn))
+
+     def pop(self):
+         self.context_stack.pop()
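+
+     # A minimal sketch of the switch stack above; a plain lambda satisfies the
+     # FunctionType check and is invoked on push (names are illustrative):
+     #     switches = _ContextSwitchInfo(False)
+     #     switches.push(True, lambda: None)  # runs the callable, records the switch
+     #     switches.pop()                     # drops the latest record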
+
+
+ class _Context:
+     """
+     _Context is the environment in which operations are executed.
+
+     Note:
+         Creating a context by instantiating a Context object is not recommended.
+         Use context() to get the context instead, since Context is a singleton.
+     """
+     _instance = None
+     _instance_lock = threading.Lock()
+
+     def __new__(cls, *args, **kwargs):
+         if cls._instance is None:
+             with cls._instance_lock:
+                 # Re-check under the lock so two racing threads cannot both
+                 # create an instance (double-checked locking).
+                 if cls._instance is None:
+                     cls._instance = object.__new__(cls)
+         return cls._instance
+
+     def __init__(self):
+         self._thread_local_info = _ThreadLocalInfo()
+         self._context_switches = _ContextSwitchInfo(False)
+         self._context_handle = MSContext.get_instance()
+         self._support_binary = False
+         self.enable_compile_cache = None
+         self._mode = PYNATIVE_MODE
+         self._jit_config = {}
+
+     def __getattribute__(self, attr):
+         value = object.__getattribute__(self, attr)
+         if attr == "_context_handle" and value is None:
+             raise ValueError("Get {} failed, please check whether 'env_config_path' is correct.".format(attr))
+         return value
+
+     def get_param(self, param):
+         return self._context_handle.get_param(param)
+
+     def set_param(self, param, value):
+         self._context_handle.set_param(param, value)
+
+     def get_mode(self):
+         """Get current mode."""
+         return self._mode
+
+     def get_jit_config(self):
+         """Get current jit_config."""
+         return self._jit_config
+
+     def set_mode(self, mode):
+         """
+         Switch between Graph mode and PyNative mode.
+
+         Args:
+             mode (int): GRAPH_MODE or PYNATIVE_MODE.
+         """
+         if mode == PYNATIVE_MODE:
+             if self.enable_debug_runtime:
+                 self.set_backend_policy("vm")
+             parallel_mode = _get_auto_parallel_context("parallel_mode")
+             if parallel_mode not in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE, ParallelMode.AUTO_PARALLEL):
+                 raise ValueError(f"Got {parallel_mode}, but PyNative mode does not support SEMI_AUTO_PARALLEL; "
+                                  f"you should set either "
+                                  f"context.set_auto_parallel_context(parallel_mode='data_parallel'), "
+                                  f"context.set_auto_parallel_context(parallel_mode='stand_alone') "
+                                  f"or context.set_auto_parallel_context(parallel_mode='auto_parallel').")
+             self._context_switches.push(True, None)
+         elif mode == GRAPH_MODE:
+             if self.enable_debug_runtime:
+                 self.set_backend_policy("ge")
+             self._context_switches.push(False, None)
+         else:
+             raise ValueError(f"For 'context.set_context', the argument 'mode' should be context.GRAPH_MODE (0) "
+                              f"or context.PYNATIVE_MODE (1), but got {mode}.")
+         self.set_param(ms_ctx_param.mode, mode)
+         self._mode = mode
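+
+     # A minimal usage sketch through the public set_context wrapper defined
+     # later in this module (assuming a standard MindSpore install):
+     #     from mindspore import context
+     #     context.set_context(mode=context.GRAPH_MODE)     # compile to a graph
+     #     context.set_context(mode=context.PYNATIVE_MODE)  # eager execution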
+
+     def set_jit_syntax_level(self, level):
+         """Set the JIT syntax level for graph compiling."""
+         if level != STRICT and level != COMPATIBLE and level != LAX:
+             raise ValueError(f"For 'context.set_jit_syntax_level', the argument 'level' should be context.STRICT, "
+                              f"context.COMPATIBLE or context.LAX, but got {level}.")
+         self.set_param(ms_ctx_param.jit_syntax_level, level)
+
+     def set_debug_level(self, level):
+         """Set the debug level for graph compiling."""
+         if level != RELEASE and level != DEBUG:
+             raise ValueError(f"For 'context.set_debug_level', the argument 'level' should be context.RELEASE "
+                              f"or context.DEBUG, but got {level}.")
+         self.set_param(ms_ctx_param.debug_level, level)
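+
+     # Both levels are the integer enums defined at module top (STRICT=0,
+     # COMPATIBLE=1, LAX=2 and RELEASE=0, DEBUG=1), e.g.:
+     #     context.set_context(jit_syntax_level=context.STRICT)
+     #     context.set_context(debug_level=context.DEBUG)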
+
+     def set_memory_optimize_level(self, memory_optimize_level):
+         """
+         The memory optimize level, supporting "O0" and "O1".
+
+         Args:
+             memory_optimize_level (str): "O0", "O1"
+         """
+         memory_optimize_levels = ["O0", "O1"]
+         if memory_optimize_level not in memory_optimize_levels:
+             raise ValueError(f"For 'context.set_context', the argument 'memory_optimize_level' must be one of "
+                              f"{memory_optimize_levels}, but got {memory_optimize_level}.")
+         if memory_optimize_level == "O0":
+             self.set_param(ms_ctx_param.memory_optimize_level, 0)
+         else:
+             self.set_param(ms_ctx_param.memory_optimize_level, 1)
+
+     def set_memory_offload(self, memory_offload):
+         """
+         Enable memory offload or not, supporting "ON" and "OFF".
+
+         Args:
+             memory_offload (str): "ON", "OFF"
+         """
+         memory_offload_options = ["ON", "OFF"]
+         if memory_offload not in memory_offload_options:
+             raise ValueError(f"For 'context.set_context', the argument 'memory_offload' must be one of "
+                              f"{memory_offload_options}, but got {memory_offload}.")
+         if memory_offload == "ON":
+             self.set_param(ms_ctx_param.memory_offload, True)
+         else:
+             self.set_param(ms_ctx_param.memory_offload, False)
+
+     def set_deterministic(self, deterministic):
+         """
+         Enable or disable deterministic execution, supporting the values "ON" and "OFF".
+
+         Args:
+             deterministic (str): "ON", "OFF"
+         """
+         deterministic_options = ["ON", "OFF"]
+         if deterministic not in deterministic_options:
+             raise ValueError(f"For 'context.set_context', the argument 'deterministic' must be one of "
+                              f"{deterministic_options}, but got {deterministic}.")
+         self.set_param(ms_ctx_param.deterministic, deterministic)
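+
+     # A minimal sketch of the three switches above via set_context (the
+     # string and level values are the ones validated above):
+     #     context.set_context(memory_optimize_level="O1",
+     #                         memory_offload="ON",
+     #                         deterministic="ON")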
+
+     def set_ascend_config(self, ascend_config):
+         """
+         Enable ascend config.
+
+         Args:
+             ascend_config (dict):
+                 - precision_mode (str): "force_fp16", "allow_fp32_to_fp16", "allow_mix_precision",
+                   "must_keep_origin_dtype", "force_fp32", "allow_fp32_to_bf16",
+                   "allow_mix_precision_fp16" and "allow_mix_precision_bf16".
+                 - jit_compile (bool): ``False`` and ``True``.
+                 - atomic_clean_policy (int): ``0`` and ``1``. Default: ``1``.
+                 - op_precision_mode (str): precision mode config file path.
+                 - op_debug_option (str): Enable debugging options for Ascend operators,
+                   not enabled by default, only supports ``"oom"`` currently.
+                   ``"oom"``: Detect memory out of bounds.
+                 - ge_options (dict): Global or session CANN options.
+                 - exception_dump (str): Enable exception dump for Ascend operators. ``"0"``, ``"1"`` and ``"2"``.
+                   Default: ``"2"``.
+                 - parallel_speed_up_json_path (Union[str, None]): The path to the parallel speed up json file.
+                   If its value is None or '', it does not take effect. Default: None.
+                 - host_scheduling_max_threshold (int): The host scheduling max threshold.
+         """
+         ascend_cfg_modes = {
+             'precision_mode': ["force_fp16", "allow_fp32_to_fp16", "allow_mix_precision", "must_keep_origin_dtype",
+                                "force_fp32", "allow_fp32_to_bf16", "allow_mix_precision_fp16",
+                                "allow_mix_precision_bf16"],
+             'jit_compile': [True, False],
+             'atomic_clean_policy': [0, 1],
+             'matmul_allow_hf32': [True, False],
+             'conv_allow_hf32': [True, False],
+             'exception_dump': ["0", "1", "2"],
+             'op_precision_mode': (str,),
+             'ge_options': (dict,),
+             # type(None) (not None itself) so the isinstance() check below is valid.
+             'parallel_speed_up_json_path': (str, type(None)),
+             'host_scheduling_max_threshold': (int,),
+             'cur_step_num': (int,),
+             'save_checkpoint_steps': (int,),
+             'need_ckpt': (bool,),
+             'last_triggered_step': (int,),
+             'topo_order': (dict,),
+             'op_debug_option': (str, type(None)),
+         }
+         ascend_cfg_setters = {
+             'precision_mode': self._get_ascend_config_setter('precision_mode'),
+             'jit_compile': self._get_ascend_config_setter('jit_compile', lambda v: "1" if v else "0"),
+             'atomic_clean_policy': self._get_ascend_config_setter('atomic_clean_policy', str),
+             'matmul_allow_hf32': self._get_ascend_config_setter('matmul_allow_hf32', lambda v: "1" if v else "0"),
+             'conv_allow_hf32': self._get_ascend_config_setter('conv_allow_hf32', lambda v: "1" if v else "0"),
+             'exception_dump': self._get_ascend_config_setter('exception_dump'),
+             'op_debug_option': self._set_op_debug_option,
+             'op_precision_mode': self._set_op_precision_mode,
+             'ge_options': self._set_ge_options,
+             'parallel_speed_up_json_path': self._set_speedup_config_path,
+             'host_scheduling_max_threshold': self._get_ascend_config_setter('host_scheduling_max_threshold', str),
+             'cur_step_num': self._set_cur_step_num,
+             'save_checkpoint_steps': self._set_save_checkpoint_steps,
+             'need_ckpt': self._set_need_ckpt,
+             'last_triggered_step': self._set_last_triggered_step,
+             'topo_order': self._set_topo_order
+         }
+         ascend_cfg_set = tuple(ascend_cfg_modes.keys())
+         for ascend_key, ascend_value in ascend_config.items():
+             if ascend_key not in ascend_cfg_set:
+                 raise ValueError(f"For 'context.set_context', the key of argument 'ascend_config' must be one of "
+                                  f"{ascend_cfg_set}, but got {ascend_key}.")
+             supported_modes = ascend_cfg_modes.get(ascend_key)
+             if isinstance(supported_modes, list) and ascend_value not in supported_modes:
+                 raise ValueError(f"For 'ascend_config', the value of argument {ascend_key} must be one of "
+                                  f"{supported_modes}, but got {ascend_value}.")
+             if isinstance(supported_modes, tuple) and not isinstance(ascend_value, supported_modes):
+                 raise TypeError(f"For 'ascend_config', the type of argument {ascend_key} must be one of "
+                                 f"{supported_modes}, but got {type(ascend_value)}.")
+             cfg_setter = ascend_cfg_setters.get(ascend_key)
+             cfg_setter(ascend_value)
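+
+     # A minimal sketch of a valid ascend_config, dispatched entry by entry
+     # through the setter table above (values drawn from the allowed lists):
+     #     context.set_context(ascend_config={
+     #         "precision_mode": "must_keep_origin_dtype",
+     #         "jit_compile": False,        # stored as "0" by the lambda above
+     #         "atomic_clean_policy": 1,
+     #     })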
+
+     def set_gpu_config(self, gpu_config):
+         """
+         Enable gpu config.
+
+         Args:
+             gpu_config (dict):
+
+                 - conv_fprop_algo (str): "normal", "performance", or a user-specified conv forward algorithm.
+                 - conv_dgrad_algo (str): "normal", "performance", or a user-specified conv data grad algorithm.
+                 - conv_wgrad_algo (str): "normal", "performance", or a user-specified conv weight grad algorithm.
+                 - conv_allow_tf32 (bool): ``False`` and ``True``.
+                 - matmul_allow_tf32 (bool): ``False`` and ``True``.
+         """
+
+         gpu_cfgs = {'conv_fprop_algo': ["normal", "performance", "implicit_gemm", "precomp_gemm", "gemm", "direct",
+                                         "fft", "fft_tiling", "winograd", "winograd_nonfused"],
+                     'conv_dgrad_algo': ["normal", "performance", "algo_0", "algo_1", "fft", "fft_tiling", "winograd",
+                                         "winograd_nonfused"],
+                     'conv_wgrad_algo': ["normal", "performance", "algo_0", "algo_1", "fft", "algo_3", "fft_tiling",
+                                         "winograd_nonfused"],
+                     'conv_allow_tf32': [True, False],
+                     'matmul_allow_tf32': [True, False]}
+         for gpu_key in gpu_config:
+             if gpu_key not in gpu_cfgs:
+                 raise ValueError(f"For 'context.set_context', the key of argument 'gpu_config' must be one of "
+                                  f"{gpu_cfgs}, but got {gpu_key}.")
+             supported_value = gpu_cfgs.get(gpu_key)
+             if gpu_config[gpu_key] not in supported_value:
+                 raise ValueError(f"For 'gpu_config', the value of argument {gpu_key} must be one of "
+                                  f"{supported_value}, but got {gpu_config[gpu_key]}.")
+             if gpu_key == 'conv_fprop_algo':
+                 self.set_param(ms_ctx_param.conv_fprop_algo, gpu_config[gpu_key])
+             if gpu_key == 'conv_dgrad_algo':
+                 self.set_param(ms_ctx_param.conv_dgrad_algo, gpu_config[gpu_key])
+             if gpu_key == 'conv_wgrad_algo':
+                 self.set_param(ms_ctx_param.conv_wgrad_algo, gpu_config[gpu_key])
+             if gpu_key == 'conv_allow_tf32':
+                 self.set_param(ms_ctx_param.conv_allow_tf32, gpu_config[gpu_key])
+             if gpu_key == 'matmul_allow_tf32':
+                 self.set_param(ms_ctx_param.matmul_allow_tf32, gpu_config[gpu_key])
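+
+     # A minimal sketch of a valid gpu_config (algorithm names must come from
+     # the allowed lists above):
+     #     context.set_context(gpu_config={"conv_fprop_algo": "performance",
+     #                                     "matmul_allow_tf32": True})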
+
+     def set_jit_config(self, jit_config):
+         """
+         Enable jit config.
+
+         Args:
+             jit_config (dict):
+
+                 - jit_level (str): "O0", "O1" or "O2" to control the compilation optimization level.
+                 - infer_boost (str): "on" or "off" to enable or disable inference acceleration.
+         """
+         jit_cfgs = {'jit_level': ["O0", "O1", "O2"], 'infer_boost': ["on", "off"]}
+         key_args_map = {'jit_level': ms_ctx_param.jit_level, 'infer_boost': ms_ctx_param.infer_boost}
+         for jit_key in jit_config:
+             if jit_key not in jit_cfgs:
+                 raise ValueError(f"For 'context.set_context', the key of argument 'jit_config' must be one of "
+                                  f"{jit_cfgs}, but got {jit_key}.")
+             supported_value = jit_cfgs.get(jit_key)
+             if jit_config[jit_key] not in supported_value:
+                 raise ValueError(f"For 'jit_config', the value of argument {jit_key} must be one of "
+                                  f"{supported_value}, but got {jit_config[jit_key]}.")
+             self._jit_config = jit_config
+             self.set_param(key_args_map[jit_key], jit_config[jit_key])
+
+         # Use .get() so a missing 'jit_level' raises this ValueError instead of a KeyError.
+         if 'infer_boost' in jit_config and jit_config['infer_boost'] == "on" and jit_config.get('jit_level') != "O0":
+             raise ValueError("For 'jit_config', 'infer_boost' can only be set to 'on' when 'jit_level' is 'O0'.")
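+
+     # A minimal sketch of a jit_config that passes the checks above; note
+     # that infer_boost="on" requires jit_level="O0":
+     #     context.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"})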
+
+     def set_backend_policy(self, policy):
+         success = self._context_handle.set_backend_policy(policy)
+         if not success:
+             raise RuntimeError("Backend policy must be one of ['ge', 'vm', 'ms']. "
+                                "But got {}.".format(policy))
+
+     def set_save_graphs_path(self, save_graphs_path):
+         self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path))
+
+     def set_device_target(self, target):
+         """
+         The target device to run, supporting "Ascend", "GPU", and "CPU".
+
+         Args:
+             target (str): "Ascend", "GPU", and "CPU".
+         """
+         valid_targets = ["CPU", "GPU", "Ascend", "Davinci"]
+         if target not in valid_targets:
+             raise ValueError(f"For 'context.set_context', the argument 'device_target' must be one of "
+                              f"{valid_targets}, but got {target}.")
+         if target == "Davinci":
+             target = "Ascend"
+             logger.warning("The device 'Davinci' is deprecated and will be removed in the next version. "
+                            "For 'context.set_context', please set the argument 'device_target' "
+                            "to 'CPU', 'GPU' or 'Ascend'; if you set it to 'Davinci', it will be automatically "
+                            "changed to 'Ascend'.")
+         # If in Parameter Server mode, the Ascend card should not be used by the server and scheduler.
+         if _need_reset_device_target_for_ps(target):
+             logger.info("Reset device target to CPU when set_device_target.")
+             target = "CPU"
+         self.set_param(ms_ctx_param.device_target, target)
+         if self.enable_debug_runtime and target == "CPU":
+             self.set_backend_policy("vm")
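+
+     # A minimal usage sketch; "Davinci" is still accepted but is remapped to
+     # "Ascend" with a deprecation warning, as implemented above:
+     #     context.set_context(device_target="CPU")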
+
+     def set_aoe_tune_mode(self, tune_mode):
+         """
+         Set the aoe tune mode, supporting "online" and "offline".
+
+         Args:
+             tune_mode (str): "online" and "offline".
+         """
+         candidate = ["online", "offline"]
+         if tune_mode in candidate:
+             self.set_param(ms_ctx_param.aoe_tune_mode, tune_mode)
+         else:
+             raise ValueError(f"For 'context.set_context', the argument 'aoe_tune_mode' must be in "
+                              f"['online', 'offline'], but got {tune_mode}.")
+
+     def set_aoe_config(self, aoe_config):
+         """
+         Enable aoe config.
+
+         Args:
+             aoe_config (dict):
+                 - job_type (str): ``"1"``, ``"2"``. Default: ``"2"``.
+
+                   - ``"1"``: subgraph tuning.
+                   - ``"2"``: operator tuning.
+         """
+
+         aoe_cfgs = {'job_type': ["1", "2"]}
+         for aoe_config_key in aoe_config:
+             if aoe_config_key not in aoe_cfgs:
+                 raise ValueError(f"For 'context.set_context', the key of argument 'aoe_config' must be one of "
+                                  f"{aoe_cfgs}, but got {aoe_config_key}.")
+             supported_value = aoe_cfgs.get(aoe_config_key)
+             if aoe_config[aoe_config_key] not in supported_value:
+                 raise ValueError(f"For 'aoe_config', the value of argument {aoe_config_key} must be one of "
+                                  f"{supported_value}, but got {aoe_config[aoe_config_key]}.")
+             if aoe_config_key == 'job_type':
+                 self.set_param(ms_ctx_param.aoe_job_type, aoe_config[aoe_config_key])
+
+     def set_device_id(self, device_id):
+         if device_id < 0 or device_id > 4095:
+             raise ValueError(f"For 'context.set_context', the argument 'device_id' must be in range [0, 4095], "
+                              f"but got {device_id}.")
+         self.set_param(ms_ctx_param.device_id, device_id)
+
+     def set_max_call_depth(self, max_call_depth):
+         if max_call_depth <= 0:
+             raise ValueError(f"For 'context.set_context', the argument 'max_call_depth' must be greater than 0, "
+                              f"but got {max_call_depth}.")
+         self.set_param(ms_ctx_param.max_call_depth, max_call_depth)
+
+     def set_profiling_options(self, option):
+         if not isinstance(option, str):
+             raise TypeError("For 'context.set_context', the argument 'profiling_option' must be a string, "
+                             "but got {}.".format(type(option)))
+         self.set_param(ms_ctx_param.profiling_options, option)
+
+     def set_variable_memory_max_size(self, variable_memory_max_size):
+         """Set the values of variable_memory_max_size and graph_memory_max_size."""
+         logger.warning("For 'context.set_context', the parameter 'variable_memory_max_size' is deprecated, "
+                        "and will be removed in a future "
+                        "version. Please use parameter 'max_device_memory' instead.")
+         if not Validator.check_str_by_regular(variable_memory_max_size, _RE_PATTERN):
+             raise ValueError("For 'context.set_context', the argument 'variable_memory_max_size' should be in correct"
+                              " format! It must be a string ending with 'GB', in addition to that, it must contain "
+                              "only numbers or decimal points, such as \"5GB\" or \"3.5GB\", but got {}."
+                              .format(variable_memory_max_size))
+         if float(variable_memory_max_size[:-2]) > _DEVICE_APP_MEMORY_SIZE:
+             raise ValueError("For 'context.set_context', the argument 'variable_memory_max_size' should not be "
+                              "greater than 31GB, but got {}.".format(variable_memory_max_size))
+         variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
+         # int(float(...)) so documented fractional values such as "3.5GB" do not crash int().
+         graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(float(variable_memory_max_size[:-2]))
+         graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
+         self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
+         self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)
+
+     def set_max_device_memory(self, max_device_memory):
+         if not Validator.check_str_by_regular(max_device_memory, _RE_PATTERN):
+             raise ValueError("For 'context.set_context', the argument 'max_device_memory' should be in correct "
+                              "format! It must be a string ending with 'GB', in addition to that, it must contain "
+                              "only numbers or decimal points, such as \"5GB\" or \"3.5GB\", but got {}."
+                              .format(max_device_memory))
+         max_device_memory_value = float(max_device_memory[:-2])
+         if max_device_memory_value == 0:
+             raise ValueError("For 'context.set_context', the argument 'max_device_memory' should not be \"0GB\".")
+         self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)
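+
+     # The memory sizes above are "GB"-suffixed strings validated against
+     # _RE_PATTERN before the numeric part is extracted, e.g.:
+     #     context.set_context(max_device_memory="3.5GB")  # stored as float 3.5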
+
+     def set_mempool_block_size(self, mempool_block_size):
+         """Set the block size of the memory pool."""
+         global_jit_config = get_jit_config()
+         is_force_kbk = False
+         if global_jit_config:
+             is_force_kbk = global_jit_config.get('jit_level') == "O0" or global_jit_config.get('jit_level') == "O1"
+         if _get_mode() == GRAPH_MODE and not is_force_kbk:
+             logger.warning("Graph mode doesn't currently support setting the parameter 'mempool_block_size' of "
+                            "context; you can use context.set_context to switch to PyNative mode, or set "
+                            "jit_level=O0/O1.")
+             return
+         if not Validator.check_str_by_regular(mempool_block_size, _RE_PATTERN):
+             raise ValueError("For 'context.set_context', the argument 'mempool_block_size' should be in "
+                              "correct format! Such as \"10GB\", "
+                              "but got {}".format(mempool_block_size))
+         mempool_block_size_value = float(mempool_block_size[:-2])
+         if mempool_block_size_value < 1.0:
+             raise ValueError("For 'context.set_context', the argument 'mempool_block_size' should be "
+                              "greater than or equal to \"1GB\", "
+                              "but got {}GB".format(float(mempool_block_size[:-2])))
+         self.set_param(ms_ctx_param.mempool_block_size, mempool_block_size_value)
+
+     def set_print_file_path(self, file_path):
+         """Set the print file path, adding a timestamp suffix to the file name."""
+         print_file_path = os.path.realpath(file_path)
+         if os.path.isdir(print_file_path):
+             raise IOError("For 'context.set_context', the argument 'print_file_path' should be a file path, "
+                           "but got directory {}.".format(file_path))
+
+         if os.path.exists(print_file_path):
+             _path, _file_name = os.path.split(print_file_path)
+             path = _make_directory(_path)
+             file_name = _get_print_file_name(_file_name)
+             full_file_name = os.path.join(path, file_name)
+         else:
+             full_file_name = print_file_path
+         self.set_param(ms_ctx_param.print_file_path, full_file_name)
+
+     def set_env_config_path(self, env_config_path):
+         """Check and set env_config_path."""
+         if not self._context_handle.enable_dump_ir():
+             raise ValueError("For 'context.set_context', the argument 'env_config_path' is not supported, please "
+                              "enable ENABLE_DUMP_IR with '-D on' and recompile the source first.")
+         env_config_path = os.path.realpath(env_config_path)
+         if not os.path.isfile(env_config_path):
+             raise ValueError("For 'context.set_context', the 'env_config_path' file %r does not exist, "
+                              "please check whether 'env_config_path' is correct." % env_config_path)
+         try:
+             with open(env_config_path, 'r') as f:
+                 json.load(f)
+         except (TypeError, ValueError) as exo:
+             raise ValueError(str(exo) + "\nFor 'context.set_context', opening or loading the 'env_config_path' file "
+                              "{} failed, please check whether 'env_config_path' is a correct json file, "
+                              "or whether you have permission to read it.".format(env_config_path)) from exo
+         self.set_param(ms_ctx_param.env_config_path, env_config_path)
+
+     def set_runtime_num_threads(self, runtime_num_threads):
+         """Check and set runtime_num_threads."""
+         if runtime_num_threads < 0:
+             raise ValueError("The number of threads must be greater than or equal to 0.")
+         self.set_param(ms_ctx_param.runtime_num_threads, runtime_num_threads)
+
+     def set_op_timeout(self, op_timeout):
+         """Set the maximum duration of executing an operator in seconds."""
+         if op_timeout < 0:
+             raise ValueError("The op execution timeout must be greater than or equal to 0.")
+         self.set_param(ms_ctx_param.op_timeout, op_timeout)
+
+     def set_inter_op_parallel_num(self, inter_op_parallel_num):
+         """Check and set inter_op_parallel_num."""
+         if inter_op_parallel_num < 0:
+             raise ValueError("The number of parallel threads must be greater than or equal to 0.")
+         self.set_param(ms_ctx_param.inter_op_parallel_num, inter_op_parallel_num)
+
+     setters = {
+         'mode': set_mode,
+         'save_graphs_path': set_save_graphs_path,
+         'device_target': set_device_target,
+         'aoe_tune_mode': set_aoe_tune_mode,
+         'device_id': set_device_id,
+         'max_call_depth': set_max_call_depth,
+         'profiling_options': set_profiling_options,
+         'variable_memory_max_size': set_variable_memory_max_size,
+         'max_device_memory': set_max_device_memory,
+         'mempool_block_size': set_mempool_block_size,
+         'print_file_path': set_print_file_path,
+         'env_config_path': set_env_config_path,
+         'inter_op_parallel_num': set_inter_op_parallel_num,
+         'runtime_num_threads': set_runtime_num_threads,
+         'memory_optimize_level': set_memory_optimize_level,
+         'op_timeout': set_op_timeout,
+         'memory_offload': set_memory_offload,
+         'deterministic': set_deterministic,
+         'ascend_config': set_ascend_config,
+         'jit_syntax_level': set_jit_syntax_level,
+         'debug_level': set_debug_level,
+         'gpu_config': set_gpu_config,
+         'aoe_config': set_aoe_config,
+         'jit_config': set_jit_config,
+     }
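+
+     # This table is how set_context keyword arguments reach their dedicated
+     # setters; a sketch of the dispatch, assuming `ctx` is the _Context
+     # singleton and `kwargs` the keyword arguments of set_context:
+     #     for key, value in kwargs.items():
+     #         if key in _Context.setters:
+     #             _Context.setters[key](ctx, value)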
+
+     @property
+     def reserve_class_name_in_scope(self):
+         """Get whether to save the network class name in the scope."""
+         return self._thread_local_info.reserve_class_name_in_scope
+
+     @reserve_class_name_in_scope.setter
+     def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
+         """Set whether to save the network class name in the scope."""
+         if not isinstance(reserve_class_name_in_scope, bool):
+             raise ValueError("For 'context.set_context', the type of the property 'reserve_class_name_in_scope' must "
+                              "be bool, but got {}.".format(type(reserve_class_name_in_scope)))
+         self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
+
+     @property
+     def enable_ge(self):
+         return self._context_handle.get_backend_policy() == 'ge'
+
+     @property
+     def enable_debug_runtime(self):
+         return self._thread_local_info.debug_runtime
+
+     @enable_debug_runtime.setter
+     def enable_debug_runtime(self, enable):
+         thread_info = self._thread_local_info
+         thread_info.debug_runtime = enable
+
+     @property
+     def support_binary(self):
+         """Whether running .pyc or .so files in graph mode is supported."""
+         return self._support_binary
+
+     @support_binary.setter
+     def support_binary(self, support: bool):
+         if not isinstance(support, bool):
+             raise TypeError(f"The attribute 'support_binary' should be a bool, but got {type(support)}.")
+         self._support_binary = support
+
+     def _get_ascend_config_setter(self, ascend_key, trans_fn=None):
+         def _config_setter(ascend_value):
+             self.set_param(ms_ctx_param.__members__[ascend_key], trans_fn(ascend_value))
+
+         if trans_fn is None:
+             trans_fn = lambda x: x
+         return _config_setter
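+
+     # The factory above returns a per-key setter closure; for example,
+     # _get_ascend_config_setter('jit_compile', lambda v: "1" if v else "0")
+     # yields a function that writes the translated value into the matching
+     # ms_ctx_param member. The late `trans_fn = lambda x: x` default still
+     # applies inside the closure because trans_fn is resolved at call time.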
+
+     def _set_op_debug_option(self, option_value):
+         valid_order = {'oom'}
+         if not isinstance(option_value, str):
+             raise TypeError(f"For 'ascend_config', the type of 'op_debug_option' must be str, "
+                             f"but got {type(option_value)}.")
+         if option_value not in valid_order:
+             raise ValueError(f"For 'ascend_config', the 'op_debug_option' supports being set to 'oom' currently, "
+                              f"but got {option_value}.")
+         self.set_param(ms_ctx_param.op_debug_option, option_value)
+
+     def _set_op_precision_mode(self, ascend_value):
+         op_precision_path = ascend_value
+         real_path = os.path.realpath(op_precision_path)
+         if not os.path.exists(real_path):
+             raise ValueError(f"For 'ascend_config', the 'op_precision_mode' is an invalid path, "
+                              f"got '{op_precision_path}'.")
+         self.set_param(ms_ctx_param.op_precision_mode, ascend_value)
+
+     def _set_ge_options(self, ge_options):
+         """Set ge options."""
+         for level, options in ge_options.items():
+             if level not in ['global', 'session']:
+                 raise ValueError(f"For 'ascend_config', the key of ge_options must be one of "
+                                  f"('global', 'session'), but got {level}.")
+
+             if not isinstance(options, dict):
+                 raise TypeError(f"For 'ge_options', the type of {level} options must be dict, "
+                                 f"but got {type(options)}. The error options: {options}.")
+
+             for key, value in options.items():
+                 if not isinstance(key, str):
+                     raise TypeError(f"For 'ge_options', the type of key and value must be str, "
+                                     f"but got {type(key)}. The error key is {key}.")
+                 if not isinstance(value, str):
+                     raise TypeError(f"For 'ge_options', the type of key and value must be str, "
+                                     f"but got {type(value)}. The error value is {value}.")
+
+         options_str = json.dumps(ge_options)
+         self.set_param(ms_ctx_param.ge_options, options_str)
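+
+     # A minimal ge_options sketch that satisfies the validation above: string
+     # keys and values grouped under 'global' and/or 'session' (the CANN
+     # option name shown is illustrative only):
+     #     context.set_context(ascend_config={"ge_options": {
+     #         "global": {"ge.exec.some_option": "1"},
+     #     }})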
+
+     def _set_topo_order(self, topo_order):
+         """
+         Set topo order.
+
+         Args:
+             topo_order (dict):
+                 key: str, the name of the graph.
+                 value: str, the topo order of the graph, should be one of 'dfs', 'bfs', 'rdfs'.
+         """
+         valid_order = {'dfs', 'bfs', 'rdfs'}
+         if not isinstance(topo_order, dict):
+             raise TypeError(f"For 'ascend_config', the 'topo_order' should be a dict, "
+                             f"got '{type(topo_order)}'.")
+         for k, v in topo_order.items():
+             if not isinstance(k, str):
+                 raise TypeError("key {} is not a str".format(k))
+             if v not in valid_order:
+                 raise ValueError("value {} should be one of {}.".format(v, valid_order))
+
+         options_str = json.dumps(topo_order)
+         self.set_param(ms_ctx_param.topo_order, options_str)
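For instance (a sketch; the graph names are hypothetical, only the 'dfs'/'bfs'/'rdfs' values are fixed by the check above):

    _context()._set_topo_order({"graph_0": "dfs", "graph_1": "rdfs"})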
+
+     def _set_need_ckpt(self, need_ckpt):
+         """Set need ckpt flag."""
+         if not isinstance(need_ckpt, bool):
+             raise TypeError(f"For 'need_ckpt', the value type should be bool, but got {type(need_ckpt)}, {need_ckpt}")
+         self.set_param(ms_ctx_param.need_ckpt, need_ckpt)
+
+     def _set_cur_step_num(self, step_num):
+         """Set current step num at every step begin."""
+         if not isinstance(step_num, int):
+             raise TypeError(f"For 'cur_step_num', the value type should be int, but got {type(step_num)}, {step_num}")
+         self.set_param(ms_ctx_param.cur_step_num, step_num)
+
+     def _set_save_checkpoint_steps(self, steps):
+         """Set save checkpoint steps before run."""
+         if not isinstance(steps, int):
+             raise TypeError(f"For 'save_checkpoint_steps', the value type should be int, but got {type(steps)}, {steps}")
+         self.set_param(ms_ctx_param.save_checkpoint_steps, steps)
+
+     def _set_last_triggered_step(self, step):
+         """Set the last step that triggered a checkpoint save, before run."""
+         if not isinstance(step, int):
+             raise TypeError(f"For 'last_triggered_step', the value type should be int, but got {type(step)}, {step}")
+         self.set_param(ms_ctx_param.last_triggered_step, step)
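How a training loop might drive these setters (an editor's sketch with made-up step values, not code from the package):

    ctx = _context()
    ctx._set_need_ckpt(True)              # bool: checkpointing wanted
    ctx._set_save_checkpoint_steps(100)   # save every 100 steps
    ctx._set_last_triggered_step(0)       # no checkpoint saved yet
    for step in range(1, 301):
        ctx._set_cur_step_num(step)       # refreshed at every step begin
        ...                               # run one training step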
+
+     def _set_speedup_config_path(self, speedup_config_path):
+         """Check and set speedup config for auto parallel."""
+         if speedup_config_path is None or speedup_config_path == "":
+             return
+         speedup_config_real_path = os.path.abspath(speedup_config_path)
+         if not os.path.exists(speedup_config_real_path):
+             raise ValueError(f"For 'ascend_config', the path to parallel_speed_up_json: "
+                              f"{speedup_config_real_path} does not exist, please check whether the "
+                              f"'parallel_speed_up_json_path' is correct.")
+         try:
+             valid_option = {"recompute_comm_overlap": (ms_ctx_param.recompute_comm_overlap, bool),
+                             "matmul_grad_comm_overlap": (ms_ctx_param.matmul_grad_comm_overlap, bool),
+                             "enable_task_opt": (ms_ctx_param.enable_task_opt, bool),
+                             "enable_grad_comm_opt": (ms_ctx_param.enable_grad_comm_opt, bool),
+                             "recompute_allgather_overlap_fagrad":
+                                 (ms_ctx_param.recompute_allgather_overlap_fagrad, bool),
+                             "interleaved_matmul_comm": (ms_ctx_param.interleaved_matmul_comm, bool),
+                             "bias_add_comm_swap": (ms_ctx_param.bias_add_comm_swap, bool),
+                             "enable_opt_shard_comm_opt": (ms_ctx_param.enable_opt_shard_comm_opt, bool),
+                             "enable_begin_end_inline_opt": (ms_ctx_param.enable_begin_end_inline_opt, bool),
+                             "enable_concat_eliminate_opt": (ms_ctx_param.enable_concat_eliminate_opt, bool),
+                             "interleaved_layernorm_comm": (ms_ctx_param.interleaved_layernorm_comm, bool),
+                             "compute_communicate_fusion_level":
+                                 (ms_ctx_param.compute_communicate_fusion_level, int),
+                             "enable_flash_attention_load_balance":
+                                 (ms_ctx_param.enable_flash_attention_load_balance, bool)}
+             with open(speedup_config_real_path, 'r') as f:
+                 speedup_config = json.load(f)
+                 for key, value in speedup_config.items():
+                     if not isinstance(key, str):
+                         raise TypeError("key {} is not a str".format(key))
+                     if key not in valid_option:
+                         raise ValueError("key {} should be one of {}.".format(key, valid_option.keys()))
+                     set_func, valid_type = valid_option.get(key)
+                     if not isinstance(value, valid_type):
+                         raise TypeError(f"The value type of {key} must be {valid_type}, "
+                                         f"but got value is {value} and type is {type(value)}.")
+                     self.set_param(set_func, value)
+         except (TypeError, ValueError) as exo:
+             raise ValueError(str(exo) + "\nFor 'context.set_context', opening or loading the "
+                              "'speedup_config_path' file {} failed. Please check that "
+                              "'speedup_config_path' points to a correct json file and that you have "
+                              "permission to read it.".format(speedup_config_real_path)) from exo
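A sketch of writing and applying a minimal speed-up config (keys taken from the valid_option table above; the chosen values and file name are illustrative):

    import json

    cfg = {"recompute_comm_overlap": True,         # bool-typed option
           "compute_communicate_fusion_level": 1}  # int-typed option
    with open("parallel_speed_up.json", "w") as f:
        json.dump(cfg, f)
    _context()._set_speedup_config_path("parallel_speed_up.json")

An unknown key or a wrongly typed value surfaces as the ValueError raised at the end of the method.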
+
+
+ def _context():
+     """
+     Get the global _context; if the context has not been created yet, create a new one.
+
+     Returns:
+         _Context, the global context in PyNative mode.
+     """
+     global K_CONTEXT
+     if K_CONTEXT is None:
+         default_backend = 'debug'
+         try:
+             from mindspore import default_config
+             default_backend = default_config.__backend__
+         except ImportError:
+             logger.error("Failed to import default config.")
+         K_CONTEXT = _Context()
+         K_CONTEXT.enable_debug_runtime = False
+         if default_backend == 'debug':
+             K_CONTEXT.enable_debug_runtime = True
+             default_backend = 'vm'
+         K_CONTEXT.set_backend_policy(default_backend)
+     return K_CONTEXT
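`_context()` is a lazy singleton: the first call builds and configures the global `_Context`, and later calls hand back the same object.

    ctx = _context()
    assert ctx is _context()   # repeated calls return the one global instance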
+
+
+ @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
+                  auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
+                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, enable_alltoall=bool,
+                  all_reduce_fusion_config=list, pipeline_stages=int, pipeline_segments=int,
+                  pipeline_result_broadcast=bool, parallel_optimizer_config=dict,
+                  pipeline_config=dict,
+                  comm_fusion=dict, strategy_ckpt_config=dict, force_fp32_communication=bool)
+ def set_auto_parallel_context(**kwargs):
+     r"""
+     Set auto parallel context; only data parallel is supported on CPU.
+
+     Note:
+         Attribute name is required for setting attributes.
+         If a program has tasks on different parallel modes, before setting a new parallel mode for the
+         next task, interface :func:`mindspore.reset_auto_parallel_context` should be called to reset
+         the configuration.
+         Setting or changing parallel modes must be done before creating any Initializer; otherwise,
+         a RuntimeError may be raised when compiling the network.
+
+     Some configurations are parallel mode specific, see the below table for details:
+
+     =========================== ===========================
+     Common                      AUTO_PARALLEL
+     =========================== ===========================
+     device_num                  gradient_fp32_sync
+     global_rank                 loss_repeated_mean
+     gradients_mean              search_mode
+     parallel_mode               parameter_broadcast
+     all_reduce_fusion_config    strategy_ckpt_load_file
+     enable_parallel_optimizer   strategy_ckpt_save_file
+     parallel_optimizer_config   dataset_strategy
+     enable_alltoall             pipeline_stages
+     pipeline_config             auto_parallel_search_mode
+     force_fp32_communication    pipeline_result_broadcast
+     \                           comm_fusion
+     \                           strategy_ckpt_config
+     \                           group_ckpt_save_file
+     \                           auto_pipeline
+     =========================== ===========================
+
+     Args:
+         device_num (int): Available device number, the value must be in [1, 4096]. Default: ``1`` .
+         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: ``0`` .
+         gradients_mean (bool): Whether to perform the mean operator after allreduce of gradients.
+             "stand_alone" does not support gradients_mean. Default: ``False`` .
+         gradient_fp32_sync (bool): Run allreduce of gradients in fp32. "stand_alone", "data_parallel"
+             and "hybrid_parallel" do not support gradient_fp32_sync. Default: ``True`` .
+         loss_repeated_mean (bool): Indicates whether the mean operator is executed backwards when the
+             calculation is repeated. Default: ``True`` .
+         parallel_mode (str): There are five kinds of parallel modes, ``"stand_alone"`` , ``"data_parallel"`` ,
+             ``"hybrid_parallel"`` , ``"semi_auto_parallel"`` and ``"auto_parallel"`` . Note that the pynative mode
+             only supports the ``"stand_alone"`` and ``"data_parallel"`` modes. Default: ``"stand_alone"`` .
+
+             - stand_alone: Only one processor is working.
+
+             - data_parallel: Distributes the data across different processors.
+
+             - hybrid_parallel: Achieves data parallelism and model parallelism manually.
+
+             - semi_auto_parallel: Achieves data and model parallelism by setting parallel strategies.
+
+             - auto_parallel: Achieves parallelism automatically.
+         search_mode (str): There are three kinds of shard strategy search modes: ``"recursive_programming"`` ,
+             ``"sharding_propagation"`` and ``"dynamic_programming"`` (not recommended).
+             Default: ``"recursive_programming"`` .
+
+             - recursive_programming: Recursive programming search mode. In order to obtain optimal performance,
+               it is recommended that users set the batch size to be greater than or equal to the product of
+               the number of devices and the number of multi-copy parallelism.
+
+             - sharding_propagation: Propagates shardings from configured ops to non-configured ops.
+
+             - dynamic_programming: Dynamic programming search mode.
+         auto_parallel_search_mode (str): This is the old name of 'search_mode'. The attribute is retained
+             for compatibility and will be deleted in a future MindSpore version.
+         parameter_broadcast (bool): Whether to broadcast parameters before training. Before training, in order to have
+             the same network initialization parameter values for all devices, broadcast the parameters
+             on device 0 to other devices. Parameter broadcasting differs between parallel modes: in
+             ``data_parallel`` mode, all parameters are broadcast except for the parameters whose attribute
+             layerwise_parallel is ``True`` ; in ``hybrid_parallel`` , ``semi_auto_parallel`` and
+             ``auto_parallel`` modes, the segmented parameters do not participate in broadcasting.
+             Default: ``False`` .
+         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. This parameter is not
+             recommended currently; it is better to use 'strategy_ckpt_config' instead. Default: ``''``
+         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. This parameter is not
+             recommended currently; it is better to use 'strategy_ckpt_config' instead. Default: ``''``
+         full_batch (bool): If you load whole batch datasets in ``auto_parallel`` mode, this parameter
+             should be set as ``True`` . Default: ``False`` . This interface is not recommended
+             currently; it is better to use 'dataset_strategy' instead.
+         dataset_strategy (Union[str, tuple]): Dataset sharding strategy. Default: ``"data_parallel"`` .
+             dataset_strategy="data_parallel" is equal to full_batch=False, and dataset_strategy="full_batch" is
+             equal to full_batch=True. When the execution mode is 'GRAPH_MODE' and the dataset is loaded into
+             the net with a model parallel strategy such as ds_stra = ((1, 8), (1, 8)), it requires using
+             set_auto_parallel_context(dataset_strategy=ds_stra).
+         enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
+             data parallel training for the benefit of time and memory saving. Currently, auto and semi auto
+             parallel modes support all optimizers on both Ascend and GPU. Data parallel mode only supports
+             `Lamb` and `AdamWeightDecay` on Ascend. Default: ``False`` .
+         force_fp32_communication (bool): A switch that determines whether reduce operators (AllReduce, ReduceScatter)
+             are forced to use the fp32 data type during communication. ``True`` enables the
+             switch. Default: ``False`` .
+         enable_alltoall (bool): A switch that allows AllToAll operators to be generated during communication. If its
+             value is ``False`` , there will be a combination of operators such as AllGather, Split and
+             Concat instead of AllToAll. Default: ``False`` .
+         all_reduce_fusion_config (list): Set allreduce fusion strategy by parameter indices. Only supports ReduceOp.SUM
+             and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No default; if it is not set, the fusion is closed.
+         pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how the devices are
+             distributed along the pipeline. The total devices will be divided into 'pipeline_stages'
+             stages. Default: ``1`` .
+         pipeline_result_broadcast (bool): A switch that broadcasts the result of the last stage to all other
+             stages in pipeline parallel inference. Default: ``False`` .
+         pipeline_config (dict): A dict containing the keys and values for setting the pipeline parallelism
+             configuration. It supports the following keys:
+
+             - pipeline_interleave(bool): Indicates whether to enable the interleaved execution mode.
+             - pipeline_scheduler(str): Indicates the scheduling mode for pipeline parallelism. Only supports
+               ``gpipe/1f1b``.
+         parallel_optimizer_config (dict): A dict containing the keys and values for setting the parallel optimizer
+             configuration. It provides more detailed behavior control about parallel training
+             when the parallel optimizer is enabled. The configuration takes effect when we use
+             mindspore.set_auto_parallel_context(enable_parallel_optimizer=True).
+             It supports the following keys.
+
+             - gradient_accumulation_shard(bool): If ``true`` , the accumulation gradient parameters will be
+               sharded across the data parallel devices. This will introduce additional communication
+               (ReduceScatter) at each step when accumulating the gradients, but saves a lot of device
+               memory, thus the model can be trained with a larger batch size. This configuration is
+               effective only when the model runs on pipeline training or gradient accumulation with
+               data parallel. Default ``False`` .
+
+             - parallel_optimizer_threshold(int): Set the threshold of the parallel optimizer. When the parallel
+               optimizer is enabled, parameters with size smaller than this threshold will not be sharded
+               across the devices. Parameter size = shape[0] \* ... \* shape[n] \* size(dtype). Non-negative.
+               Unit: KB. Default: ``64`` .
+
+             - optimizer_weight_shard_size(int): Set the optimizer weight shard group size, if you want to
+               specify the maximum group size across devices when the parallel optimizer is enabled.
+               The numerical range can be (0, device_num]. If pipeline parallel is enabled, the numerical
+               range is (0, device_num/stage]. If the size of the data parallel communication domain
+               of the parameter cannot be divided by `optimizer_weight_shard_size`, then the specified
+               communication group size will not take effect. Default value is ``-1`` , which means the
+               optimizer weight shard group size will be the size of the data parallel group of each parameter.
+
+         comm_fusion (dict): A dict containing the types and configurations for setting the communication fusion. Each
+             communication fusion config has two keys: "mode" and "config".
+             It supports the following communication fusion types and configurations:
+
+             - openstate: Whether to turn on the communication fusion or not. If `openstate` is ``True`` ,
+               turn on the communication fusion; otherwise, turn off the communication fusion.
+               Default: ``True`` .
+
+             - allreduce: If the communication fusion type is `allreduce`, the `mode` contains: `auto`, `size`
+               and `index`. In `auto` mode, AllReduce fusion is configured by gradients size and the default
+               fusion threshold is `64` MB. In `size` mode, AllReduce fusion is configured by gradients size
+               manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is the same as
+               `all_reduce_fusion_config`.
+
+             - allgather: If the communication fusion type is `allgather`, the `mode` contains: `auto`, `size`.
+               In `auto` mode, AllGather fusion is configured by gradients size, and the default fusion
+               threshold is `64` MB. In `size` mode, AllGather fusion is configured by gradients size
+               manually, and the fusion threshold must be larger than `0` MB.
+
+             - reducescatter: If the communication fusion type is `reducescatter`, the `mode` contains: `auto`
+               and `size`. The config is the same as `allgather`.
+
+         strategy_ckpt_config (dict): A dict containing the configurations for setting the parallel strategy file. This
+             interface contains the functions of parameters `strategy_ckpt_load_file` and
+             `strategy_ckpt_save_file`; it is recommended to use this parameter to replace those two
+             parameters.
+             It contains the following configurations:
+
+             - load_file (str): The path to load parallel strategy checkpoint. If the file name extension is
+               `.json`, the file is loaded in JSON format. Otherwise, the file is loaded in ProtoBuf format.
+               Default: ``''``
+
+             - save_file (str): The path to save parallel strategy checkpoint. If the file name extension is
+               `.json`, the file is saved in JSON format. Otherwise, the file is saved in ProtoBuf format.
+               Default: ``''``
+
+             - only_trainable_params (bool): Only save/load the strategy information for trainable parameters.
+               Default: ``True`` .
+         group_ckpt_save_file (str): The path to save the parallel group checkpoint.
+         auto_pipeline (bool): Set the pipeline stage number to automatic. Its value will be selected between 1 and the
+             parameter `pipeline_stages`. This option requires the `parallel_mode` to be ``auto_parallel``
+             and the `search_mode` to be ``recursive_programming``. Default: ``False`` .
+
+     Raises:
+         ValueError: If the input key is not an attribute in auto parallel context.
+
+     Examples:
+         >>> import mindspore as ms
+         >>> ms.set_auto_parallel_context(device_num=8)
+         >>> ms.set_auto_parallel_context(global_rank=0)
+         >>> ms.set_auto_parallel_context(gradients_mean=True)
+         >>> ms.set_auto_parallel_context(gradient_fp32_sync=False)
+         >>> ms.set_auto_parallel_context(parallel_mode="auto_parallel")
+         >>> ms.set_auto_parallel_context(search_mode="recursive_programming")
+         >>> ms.set_auto_parallel_context(auto_parallel_search_mode="recursive_programming")
+         >>> ms.set_auto_parallel_context(parameter_broadcast=False)
+         >>> ms.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
+         >>> ms.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
+         >>> ms.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8)))
+         >>> ms.set_auto_parallel_context(enable_parallel_optimizer=False)
+         >>> ms.set_auto_parallel_context(enable_alltoall=False)
+         >>> ms.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
+         >>> ms.set_auto_parallel_context(pipeline_stages=2)
+         >>> ms.set_auto_parallel_context(pipeline_stages=2, pipeline_result_broadcast=True)
+         >>> parallel_config = {"gradient_accumulation_shard": True, "parallel_optimizer_threshold": 24,
+         ...                    "optimizer_weight_shard_size": 2}
+         >>> ms.set_auto_parallel_context(parallel_optimizer_config=parallel_config, enable_parallel_optimizer=True)
+         >>> config = {"allreduce": {"mode": "size", "config": 32}, "allgather": {"mode": "size", "config": 32}}
+         >>> ms.set_auto_parallel_context(comm_fusion=config)
+         >>> stra_ckpt_dict = {"load_file": "./stra0.ckpt", "save_file": "./stra1.ckpt", "only_trainable_params": False}
+         >>> ms.set_auto_parallel_context(strategy_ckpt_config=stra_ckpt_dict)
+     """
+     _set_auto_parallel_context(**kwargs)
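The pipeline_config key is not covered by the examples above; a hedged sketch of how it combines with pipeline_stages (the stage count and scheduler choice are illustrative):

    import mindspore as ms
    ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel", pipeline_stages=4,
                                 pipeline_config={"pipeline_interleave": True,
                                                  "pipeline_scheduler": "1f1b"})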
+
+
+ def get_auto_parallel_context(attr_key):
+     """
+     Get auto parallel context attribute value according to the key.
+
+     Args:
+         attr_key (str): The key of the attribute.
+
+     Returns:
+         Returns attribute value according to the key.
+
+     Raises:
+         ValueError: If the input key is not an attribute in auto parallel context.
+
+     Examples:
+         >>> import mindspore as ms
+         >>> parallel_mode = ms.get_auto_parallel_context("parallel_mode")
+         >>> dataset_strategy = ms.get_auto_parallel_context("dataset_strategy")
+     """
+     return _get_auto_parallel_context(attr_key)
+
+
+ def reset_auto_parallel_context():
+     """
+     Reset auto parallel context attributes to the default values.
+
+     - device_num: 1.
+     - global_rank: 0.
+     - gradients_mean: False.
+     - gradient_fp32_sync: True.
+     - parallel_mode: 'stand_alone'.
+     - search_mode: 'recursive_programming'.
+     - auto_parallel_search_mode: 'recursive_programming'.
+     - parameter_broadcast: False.
+     - strategy_ckpt_load_file: ''.
+     - strategy_ckpt_save_file: ''.
+     - full_batch: False.
+     - enable_parallel_optimizer: False.
+     - force_fp32_communication: False.
+     - enable_alltoall: False.
+     - pipeline_stages: 1.
+     - pipeline_result_broadcast: False.
+     - fusion_threshold: 64.
+     - auto_pipeline: False.
+
+     Examples:
+         >>> import mindspore as ms
+         >>> ms.reset_auto_parallel_context()
+     """
+     _reset_auto_parallel_context()
+
+
+ @args_type_check(offload_config=dict)
+ def set_offload_context(offload_config):
+     r"""
+     Configure heterogeneous training detailed parameters to adjust the offload strategy.
+
+     Note:
+         The offload configuration is only used if the memory offload feature is enabled
+         via mindspore.set_context(memory_offload="ON").
+
+     Args:
+         offload_config (dict): A dict containing the keys and values for setting the offload context
+             configuration. It supports the following keys.
+
+             - offload_path (str): The path of offload; relative paths are supported. Default: ``"./offload"``.
+             - offload_cpu_size (str): The cpu memory size for offload. The format is "xxGB".
+             - offload_disk_size (str): The disk size for offload. The format is "xxGB".
+             - hbm_ratio (float): The ratio that can be used based on the maximum device memory.
+               The range is (0, 1]. Default: ``1.0``.
+             - cpu_ratio (float): The ratio that can be used based on the maximum host memory.
+               The range is (0, 1]. Default: ``1.0``.
+             - enable_pinned_mem (bool): The flag of whether to enable pinned memory. Default: ``True``.
+             - enable_aio (bool): The flag of whether to enable aio. Default: ``True``.
+             - aio_block_size (str): The size of an aio block. The format is "xxGB".
+             - aio_queue_depth (int): The depth of the aio queue.
+             - offload_param (str): The offload destination for parameters, cpu or disk. Default: ``""``.
+             - offload_checkpoint (str): The offload destination for checkpoints, cpu or disk; only valid if
+               recompute is turned on. Default: ``""``.
+             - auto_offload (bool): The flag of whether to offload automatically. Default: ``True``.
+             - host_mem_block_size (str): The memory block size of the host memory pool. The format is "xxGB".
+
+     Raises:
+         ValueError: If the input key is not an attribute in auto parallel context.
+
+     Examples:
+         >>> from mindspore import context
+         >>> context.set_offload_context(offload_config={"offload_param":"cpu"})
+     """
+     _set_offload_context(offload_config)
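A fuller offload_config sketch combining several of the documented keys (the sizes and ratios are illustrative assumptions):

    from mindspore import context
    context.set_context(memory_offload="ON")   # required for the offload config to be used
    context.set_offload_context(offload_config={
        "offload_path": "./offload",           # relative path is allowed
        "offload_cpu_size": "512GB",
        "offload_disk_size": "1024GB",
        "hbm_ratio": 0.9,
        "offload_param": "cpu",
        "auto_offload": True,
    })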
+
+
+ def get_offload_context():
+     """
+     Get the offload configuration parameters, which are configured through the interface
+     mindspore.set_offload_context(). If the user has not set them, the default configuration is obtained.
+
+     Returns:
+         Dict, heterogeneous training offload detailed configuration parameters.
+
+     Examples:
+         >>> from mindspore import context
+         >>> offload_config = context.get_offload_context()
+     """
+     return _get_offload_context()
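A round-trip sketch (which keys appear in the returned dict beyond those explicitly set is an assumption of this illustration):

    from mindspore import context
    context.set_offload_context(offload_config={"offload_param": "cpu"})
    cfg = context.get_offload_context()
    print(cfg.get("offload_param"))   # expected: "cpu"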
+
+
+ def _check_target_specific_cfgs(device, arg_key):
+     """Checking whether a config is suitable for a specified device"""
+     device_cfgs = {
+         'enable_graph_kernel': ['Ascend', 'GPU', 'CPU'],
+         'graph_kernel_flags': ['Ascend', 'GPU', 'CPU'],
+         'enable_reduce_precision': ['Ascend'],
+         'print_file_path': ['Ascend'],
+         'variable_memory_max_size': ['Ascend'],
+         'max_device_memory': ['Ascend', 'GPU'],
+         'mempool_block_size': ['GPU', 'Ascend'],
+         'disable_format_transform': ['GPU'],
+         'ascend_config': ['Ascend'],
+         'gpu_config': ['GPU'],
+     }
+     # configs not in map device_cfgs are supposed to be suitable for all devices
+     if arg_key not in device_cfgs:
+         return True
+     supported_devices = device_cfgs[arg_key]
+     if device in supported_devices:
+         return True
+     logger.warning(f"For 'context.set_context', when set the argument '{arg_key}', "
+                    f"the argument 'device_target' only supports devices in '{supported_devices}', "
+                    f"but got '{device}', ignore it.")
+     return False
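The mapping makes the check's behavior easy to read off: it returns True silently unless the key is device specific and mismatched, in which case it warns and returns False.

    _check_target_specific_cfgs('GPU', 'max_device_memory')   # True: listed for GPU
    _check_target_specific_cfgs('CPU', 'print_file_path')     # False, with a warning: Ascend only
    _check_target_specific_cfgs('CPU', 'max_call_depth')      # True: key is not device specific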
+
+
+ def _check_ascend_device_context_initialized(device_target, settings):
+     if device_target == 'Ascend' and is_initialized(device_target):
+         for key, _ in settings.items():
+             if key in ('ascend_config', 'deterministic', 'jit_compile', 'exception_dump', 'device_id'):
+                 logger.warning(f"For 'context.set_context' in Ascend backend, the backend is already initialized, "
+                                "please set it before the definition of any Tensor and Parameter, and the "
+                                "instantiation and execution of any operation and net, otherwise the settings may not "
+                                "take effect.")
+                 break
+
+
+ def _check_key(key):
+     if key in ('precision_mode', 'jit_compile', 'atomic_clean_policy', 'matmul_allow_hf32', 'conv_allow_hf32',
+                'op_precision_mode', 'host_scheduling_max_threshold', 'ge_options', 'op_debug_option'):
+         raise ValueError(f"Please set '{key}' through parameter ascend_config")
+
+
+ @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=(bool, int),
+                  save_graphs_path=str, enable_dump=bool, aoe_tune_mode=str, aoe_config=dict,
+                  save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
+                  enable_auto_mixed_precision=bool, inter_op_parallel_num=int,
+                  enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool,
+                  max_device_memory=str, print_file_path=str, max_call_depth=int, env_config_path=str,
+                  graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int, load_compile_cache=bool,
+                  grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str, disable_format_transform=bool,
+                  op_timeout=int, deterministic=str, ascend_config=dict, jit_syntax_level=int, debug_level=int,
+                  jit_enable_inplace_ops=bool, gpu_config=dict, jit_config=dict, enable_compile_cache=bool)
+ def set_context(**kwargs):
+     """
+     Set context for the running environment.
+
+     Context should be configured before running your program. If there is no configuration,
+     it will be set automatically according to the device target by default.
+
+     Note:
+         Attribute name is required for setting attributes.
+         The mode is not recommended to be changed after the net is initialized because the implementations of some
+         operations are different in graph mode and pynative mode. Default: ``PYNATIVE_MODE`` .
+
+     Some configurations are device specific, see the below table for details:
+
+     +-------------------------+------------------------------+----------------------------+
+     | Function Classification | Configuration Parameters     | Hardware Platform Support  |
+     +=========================+==============================+============================+
+     | System Configuration    | device_id                    | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | device_target                | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | max_device_memory            | GPU/Ascend                 |
+     |                         +------------------------------+----------------------------+
+     |                         | variable_memory_max_size     | Ascend                     |
+     |                         +------------------------------+----------------------------+
+     |                         | mempool_block_size           | GPU/Ascend                 |
+     |                         +------------------------------+----------------------------+
+     |                         | op_timeout                   | Ascend                     |
+     +-------------------------+------------------------------+----------------------------+
+     | Debug Configuration     | save_graphs                  | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | save_graphs_path             | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | enable_dump                  | Ascend                     |
+     |                         +------------------------------+----------------------------+
+     |                         | save_dump_path               | Ascend                     |
+     |                         +------------------------------+----------------------------+
+     |                         | deterministic                | Ascend                     |
+     |                         +------------------------------+----------------------------+
+     |                         | print_file_path              | Ascend                     |
+     |                         +------------------------------+----------------------------+
+     |                         | env_config_path              | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | precompile_only              | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | reserve_class_name_in_scope  | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | pynative_synchronize         | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | debug_level                  | CPU/GPU/Ascend             |
+     +-------------------------+------------------------------+----------------------------+
+     | Executive Control       | mode                         | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | enable_graph_kernel          | Ascend/GPU                 |
+     |                         +------------------------------+----------------------------+
+     |                         | graph_kernel_flags           | Ascend/GPU                 |
+     |                         +------------------------------+----------------------------+
+     |                         | enable_reduce_precision      | Ascend                     |
+     |                         +------------------------------+----------------------------+
+     |                         | aoe_tune_mode                | Ascend                     |
+     |                         +------------------------------+----------------------------+
+     |                         | aoe_config                   | Ascend                     |
+     |                         +------------------------------+----------------------------+
+     |                         | check_bprop                  | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | max_call_depth               | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | grad_for_scalar              | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | enable_compile_cache         | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | inter_op_parallel_num        | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | runtime_num_threads          | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | compile_cache_path           | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | disable_format_transform     | GPU                        |
+     |                         +------------------------------+----------------------------+
+     |                         | support_binary               | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | memory_optimize_level        | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | memory_offload               | GPU/Ascend                 |
+     |                         +------------------------------+----------------------------+
+     |                         | ascend_config                | Ascend                     |
+     |                         +------------------------------+----------------------------+
+     |                         | jit_syntax_level             | CPU/GPU/Ascend             |
+     |                         +------------------------------+----------------------------+
+     |                         | gpu_config                   | GPU                        |
+     |                         +------------------------------+----------------------------+
+     |                         | jit_config                   | CPU/GPU/Ascend             |
+     +-------------------------+------------------------------+----------------------------+
+
+     Args:
+         device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1],
+             while device_num_per_host should be no more than 4096. Default: ``0`` .
+         device_target (str): The target device to run, supporting "Ascend", "GPU", and "CPU".
+             If the device target is not set, the device matching the installed MindSpore package is used.
+         max_device_memory (str): Set the maximum memory available for devices. The format is "xxGB".
+             Default: ``"1024GB"`` . The actually used memory size is the minimum of the available memory of the
+             device and max_device_memory. 'max_device_memory' should be set before the program runs.
+         variable_memory_max_size (str): This parameter is deprecated, and will be removed in a future version.
+             Please use parameter 'max_device_memory' instead.
+         mempool_block_size (str): Set the size of the memory pool block in PyNative mode or when the jit level is
+             'O0'/'O1' for devices. The format is "xxGB". Default: ``"1GB"`` . Minimum size is "1G". The actually
+             used memory block size is the minimum of the available memory of the device and mempool_block_size.
+         op_timeout (int): Set the maximum duration of executing an operator in seconds.
+             If the execution time exceeds this value, the system will terminate the task.
+             0 means endless wait. The defaults for AI Core and AICPU operators vary on different hardware.
+             For more information,
+             please refer to `Ascend Community document about aclrtSetOpExecuteTimeOut
+             <https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/infacldevg/aclcppdevg/aclcppdevg_03_0069.html>`_.
+             Default: ``900`` .
+         save_graphs (bool or int): Whether to save intermediate compilation graphs. Default: ``0`` .
+             Available values are:
+
+             - False or 0: disable saving of intermediate compilation graphs.
+             - 1: some intermediate files will be generated during graph compilation.
+             - True or 2: generate more ir files related to the backend process.
+             - 3: generate visualization computing graphs and detailed frontend ir graphs.
+
+             When the network structure is complex, setting the `save_graphs` attribute to ``2`` or ``3`` may take
+             too long. If you need quick problem locating, you can switch to ``1`` first.
+
+             When the `save_graphs` attribute is set as ``True`` , ``1`` , ``2`` or ``3`` , the attribute of
+             `save_graphs_path` is used to set the intermediate compilation graph storage path. By default, the
+             graphs are saved in the current directory.
+         save_graphs_path (str): Path to save graphs. Default: ``"."``.
+             If the specified directory does not exist, the system will automatically create the directory.
+             During distributed training, graphs will be saved to the directory of
+             `save_graphs_path/rank_${rank_id}/`. `rank_id` is the ID of the current device in the cluster.
+         deterministic (str): Whether to enable op run in deterministic mode. The value must be in the
+             range of ['ON', 'OFF'], and the default value is ``'OFF'`` .
+
+             - "ON": Enable operator deterministic running mode.
+             - "OFF": Disable operator deterministic running mode.
+
+             When deterministic mode is on, model ops will be deterministic in Ascend. This means that if an op runs
+             multiple times with the same inputs on the same hardware, it will have the exact same outputs each time.
+             This is useful for debugging models.
+         enable_dump (bool): This parameter is deprecated, and will be deleted in the next version.
+         save_dump_path (str): This parameter is deprecated, and will be deleted in the next version.
+         print_file_path (str): The path of saving print data. If this parameter is set, print data is saved to
+             a file by default; if print_file_path is not set, the data is displayed on the screen.
+             If the saved file already exists, a timestamp suffix will be added to the file. Saving data to a file
+             solves the problem of data loss in screen printing when a large amount of data is generated.
+             If the path is not set properly, an error will be reported, prompting to set the upper absolute path.
+             When printing data to a file, the total output bytes of a single print must be less than 2GB (limited
+             by protobuf).
+         env_config_path (str): Config path for DFX.
+             Through mindspore.set_context(env_config_path="./mindspore_config.json")
+
+             configure RDR:
+
+             - enable: controls whether the RDR is enabled to collect the key data during training and
+               save the key data in the fault scenario. When set to ``true`` , the RDR will be turned on.
+               When set to ``false`` , the RDR will be turned off.
+             - mode: sets the mode of RDR on exporting data. When set to ``1`` , the RDR only exports data
+               in the fault scenario. When set to ``2`` , the RDR exports data in the fault scenario and the
+               normal end scenario. Default: ``1`` .
+             - path: sets the path where RDR saves data. The current path must be absolute.
+
+             Memory reuse:
+
+             - mem_Reuse: controls whether the memory reuse function is turned on. When set to ``True`` ,
+               the memory reuse function is turned on. When set to ``False`` , the memory reuse function is
+               turned off.
+
+         precompile_only (bool): Whether to only precompile the network. Default: ``False`` .
+             If set to ``True`` , the network will only be compiled, not executed.
+         reserve_class_name_in_scope (bool): Whether to save the network class name in the scope. Default: ``True`` .
+             Each node has a scope. A scope of a subnode is the name of its parent node. If reserve_class_name_in_scope
+             is set to ``True`` , the class name will be saved after the keyword 'net-' in the scope.
+             For example:
+
+             Default/net-Net1/net-Net2 (reserve_class_name_in_scope=True)
+
+             Default/net/net (reserve_class_name_in_scope=False)
+
+         pynative_synchronize (bool): Whether to enable synchronous execution of the device in PyNative mode.
+             Default: ``False`` . When the value is set to ``False`` , the operator is executed asynchronously on the
+             device. When an error occurs in the execution of the operator, the specific error script code location
+             cannot be located. When the value is set to ``True`` , the operator is executed synchronously on the
+             device. This will reduce the execution performance of the program. At this time, when an error occurs
+             in the execution of the operator, the location of the error script code can be located according to the
+             call stack of the error.
+         mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).
+             Both modes support all backends. Default: ``PYNATIVE_MODE`` .
+         enable_graph_kernel (bool): Whether to enable graph kernel fusion to optimize network execution performance.
+             Default: ``False`` .
+             If enable_graph_kernel is set to ``True`` , acceleration can be enabled.
+             For details of graph kernel fusion, please check
+             `Enabling Graph Kernel Fusion
+             <https://www.mindspore.cn/tutorials/experts/en/master/optimize/graph_fusion_engine.html>`_.
+         graph_kernel_flags (str):
+             Optimization options of graph kernel fusion, whose priority is higher when it conflicts
+             with enable_graph_kernel. Only for experienced users.
+             For example,
+
+             .. code-block::
+
+                 mindspore.set_context(graph_kernel_flags="--opt_level=2 --dump_as_text")
+
+             Some general options:
+
+             - opt_level: Set the optimization level.
+               Default: ``2`` . Graph kernel fusion can be enabled equivalently by setting opt_level greater than 0.
+               Available values are:
+
+               - 0: disables graph kernel fusion;
+               - 1: enables the basic fusion of operators;
+               - 2: includes all optimizations of level 1,
+                 and turns on more optimizations such as CSE, arithmetic simplification and so on;
+               - 3: includes all optimizations of level 2, and turns on more optimizations such as StitchingFusion,
+                 ParallelFusion and so on. Optimizations of this level are radical and unstable in some scenarios.
+                 Be cautious when using this level.
+
+             - dump_as_text: dumps detailed info as text files. Default: ``False`` .
+
+         enable_reduce_precision (bool): Whether to enable precision reduction.
+             If the operator does not support the user-specified precision, the precision will
+             be changed automatically. Default: ``True`` .
+         aoe_tune_mode (str): AOE tuning mode setting, which is not set by default.
+             When set to ``"online"`` , online tuning is turned on.
+             When set to ``"offline"`` , the ge graph will be saved for offline tuning.
+         aoe_config (dict): Set the parameters specific to the Ascend Optimization Engine. It is not set by default.
+
+             - job_type (str): Mode type setting, default value is ``"2"``.
+
+               - ``"1"``: subgraph tuning;
+               - ``"2"``: operator tuning.
+
+         check_bprop (bool): Whether to check back propagation nodes. The checking ensures that the shape and dtype
+             of back propagation node outputs are the same as input parameters. Default: ``False`` .
+         max_call_depth (int): Specify the maximum depth of function calls. Must be a positive integer.
+             Default: ``1000`` .
+             The max_call_depth parameter needs to be set when the nested call is too deep or the number
+             of subgraphs is too large. If max_call_depth is set larger than before, the system max stack depth should
+             be set larger too, otherwise a `core dumped` exception may be raised because of system stack overflow.
+         grad_for_scalar (bool): Whether to get gradient for scalar. Default: ``False`` .
+             When grad_for_scalar is set to ``True`` , the function's scalar input can be derived.
+             Because the back-end does not support scaling operations currently, this interface only
+             supports simple operations that can be deduced by the front-end.
+         enable_compile_cache (bool): Whether to save or load the cache of the graph compiled by the front-end.
+             After enable_compile_cache is set to ``True`` , during the first execution, a hardware-independent
+             compilation cache is generated and exported to a MINDIR file. When the network is executed again,
+             if enable_compile_cache is still set to ``True`` and the network scripts are not changed,
+             the compile cache is loaded. Note that only limited automatic detection of changes to
+             python scripts is supported by now, which means that there is a correctness risk. Default: ``False`` .
+             This is an experimental prototype that is subject to change and/or deletion.
+         compile_cache_path (str): Path to save the compile cache. Default: ``"."``.
+             If the specified directory does not exist, the system will automatically create the directory.
+             The cache will be saved to the directory of `compile_cache_path/rank_${rank_id}/`. The `rank_id` is
+             the ID of the current device in the cluster.
+         inter_op_parallel_num(int): The thread number of op parallel at the same time. Default value is ``0`` ,
+             which means using the default number.
+         runtime_num_threads(int): The thread pool number of cpu kernels used in runtime,
+             which must be greater than or equal to 0. Default value is ``30`` ; if you run many processes at
+             the same time, you should set the value smaller to avoid thread contention.
+         disable_format_transform (bool): Whether to disable the automatic format transform function from NCHW to NHWC.
+             When the network training performance of fp16 is worse than fp32, `disable_format_transform` can be set to
+             ``True`` to try to improve training performance. Default: ``False`` .
+         support_binary (bool): Whether to support running .pyc or .so files in graph mode. To run .so or .pyc
+             files in graph mode, set 'support_binary' to ``True`` and run the .py file once. This saves the source
+             of the interfaces compiled by MindSpore into the interface definition .py file, so that file should be
+             guaranteed to be writable. The .py file can then be compiled into a .pyc or .so file, which can run in
+             graph mode.
+         memory_optimize_level (str): The memory optimize level.
+             On the Ascend hardware platform, default: ``O1``; on other hardware platforms, default: ``O0``.
+             The value must be in ['O0', 'O1'].
+
+             - O0: priority performance option, disable SOMAS (Safe Optimized Memory Allocation Solver)
+               and some other memory optimizations.
+             - O1: priority memory option, enable SOMAS and some other memory optimizations.
+         memory_offload (str): Whether to enable the memory offload function. When it is enabled, the idle data will be
+             temporarily copied to the host side in the case of insufficient device memory. The value must be in the
+             range of ['ON', 'OFF'], and the default value is ``'OFF'`` .
+
+             - ON: Enable the memory offload function. On the Ascend hardware platform, this parameter does not take
+               effect when the graph compilation level is not 'O0'; this parameter does not take effect when
+               memory_optimize_level is set to 'O1'.
+             - OFF: Turn off the memory offload function.
+         ascend_config (dict): Set the parameters specific to the Ascend hardware platform. It is not set by default.
+             The defaults of `precision_mode`, `jit_compile` and
+             `atomic_clean_policy` are experimental parameters, and may change in the future.
+
+             - precision_mode (str): Mixed precision mode setting, and the default value for inference networks
+               is ``force_fp16`` . The value range is as follows:
+
+               - force_fp16: When the operator supports both float16 and float32, select float16 directly.
+               - allow_fp32_to_fp16: For cube operators, use float16. For vector operators,
+                 prefer to keep the origin dtype; if the operator in the model can support float32,
+                 it will keep the original dtype, otherwise it will reduce to float16.
+               - allow_mix_precision: Automatic mixed precision, facing the whole network's operators; according
+                 to the built-in optimization strategy, automatically reduces the precision of some operators
+                 to float16 or bfloat16.
+               - must_keep_origin_dtype: Keep the precision of the original graph.
+               - force_fp32: When the input of the matrix calculation operator is float16 and the output supports
+                 float16 and float32, output is forced to float32.
+               - allow_fp32_to_bf16: For cube operators, use bfloat16. For vector operators,
+                 prefer to keep the origin dtype; if the operator in the model can support float32,
+                 it will keep the original dtype, otherwise it will reduce to bfloat16.
+               - allow_mix_precision_fp16: Automatic mixed precision, facing the whole network's operators;
+                 automatically reduces the precision of some operators to float16 according to the built-in
+                 optimization strategy.
+               - allow_mix_precision_bf16: Automatic mixed precision, facing the whole network's operators; according
+                 to the built-in optimization strategy, automatically reduces the precision of some operators
+                 to bfloat16.
+
+             - jit_compile (bool): Whether to select online compilation. When set to ``True``, online compilation is
+               prioritized. When set to ``False``, compiled operator binary files are prioritized to improve
+               compilation performance. The default settings are online compilation for static shapes, and compiled
+               operator binary files for dynamic shapes.
+             - atomic_clean_policy (int): The policy for cleaning memory occupied by atomic operators in the network.
+               Default: ``1`` .
+
+               - 0: The memory occupied by all atomic operators in the network is cleaned centrally.
+               - 1: Memory is not cleaned centrally, and each atomic operator in the network is cleaned separately.
+                 When the memory of the network exceeds the limit, you may try this cleaning policy, but it may cause
+                 performance loss.
+             - matmul_allow_hf32 (bool): Whether to convert FP32 to HF32 for Matmul operators. Default value: ``False``.
+               This is an experimental prototype that is subject to change and/or deletion.
+               For detailed information, please refer to `Ascend community <https://www.hiascend.com/>`_ .
+             - conv_allow_hf32 (bool): Whether to convert FP32 to HF32 for Conv operators. Default value: ``True``.
+               This is an experimental prototype that is subject to change and/or deletion.
+               For detailed information, please refer to `Ascend community <https://www.hiascend.com/>`_ .
+             - exception_dump (str): Enable exception dump for Ascend operators, providing the input and output data for
+               failing Ascend operators. The value can be ``"0"`` , ``"1"`` and ``"2"``. For ``"0"`` , exception dump is
+               turned off; for ``"1"``, all inputs and outputs will be dumped for AICore exception operators;
+               for ``"2"``, inputs will be dumped for AICore exception operators, reducing the saved information
+               but improving performance. Default: ``"2"`` .
+             - op_precision_mode (str): Path to the config file of op precision mode. For detailed information, please
+               refer to `Ascend community <https://www.hiascend.com/>`_ .
+             - op_debug_option (str): Enable debugging options for Ascend operators; not enabled by default.
+               The value currently only supports being set to ``"oom"``.
+
+               - ``"oom"``: When there is a memory out of bounds during the execution of an operator,
+                 AscendCL will return an error code of ``EZ9999``.
+
+             - ge_options (dict): Set options for CANN. The options are divided into two categories: global and session.
+               This is an experimental prototype that is subject to change and/or deletion.
+               For detailed information, please refer to `Ascend community <https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/inferapplicationdev/graphdevg/atlasgeapi_07_0119.html>`_ .
+               The configuration options in `ge_options` may be duplicated with the options in `ascend_config`. If the
+               same configuration options are set in both `ascend_config` and `ge_options`, the one set in `ge_options`
+               shall prevail.
+
+               - global (dict): Set global options.
+               - session (dict): Set session options.
+
+             - parallel_speed_up_json_path(Union[str, None]): The path to the parallel speed up json file; for the
+               configuration, refer to `parallel_speed_up.json
+               <https://gitee.com/mindspore/mindspore/blob/master/config/parallel_speed_up.json>`_ .
+               If its value is None or '', it does not take effect. Default None.
+
+               - recompute_comm_overlap (bool): Enable overlap between recompute ops and communication ops if True.
+                 Default: False.
+               - matmul_grad_comm_overlap (bool): Enable overlap between dw matmul and
+                 tensor parallel communication ops if True. Default: False.
+               - recompute_allgather_overlap_fagrad (bool): Enable overlap between duplicated allgather by recomputing
+                 in sequence parallel and flashattentionscoregrad ops if True. Default: False.
+               - enable_task_opt (bool): Enable communication fusion to optimize the number of communication operator
+                 tasks if True. Default: False.
+               - enable_grad_comm_opt (bool): Enable overlap between dx ops and data parallel communication ops if True.
+                 Currently, does not support
+                 `LazyInline <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.lazy_inline.html>`_ .
+                 Default: False.
+               - enable_opt_shard_comm_opt (bool): Enable overlap between forward ops
+                 and optimizer parallel allgather communication if True. Currently, does not support
+                 `LazyInline <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.lazy_inline.html>`_ .
+                 Default: False.
+               - compute_communicate_fusion_level (int): Enable the fusion between compute and communicate.
+                 Default: ``0``.
+
+                 - 0: Disable fusion.
+                 - 1: Apply fusion to forward nodes.
+                 - 2: Apply fusion to backward nodes.
+                 - 3: Apply fusion to all nodes.
+               - bias_add_comm_swap (bool): Enable swapping the node execution order of communication operators and add
+                 operators if ``True``. Only 1-dimensional bias nodes are supported. Default: ``False``.
+             - host_scheduling_max_threshold(int): The max threshold to control whether the dynamic shape process is
+               used when running the static graph; the default value is 0. When the number of operations in the static
+               graph is less than the max threshold, this graph will be executed in the dynamic shape process. In large
+               model scenarios, this approach can save stream resources. If the number of operations in the static
+               graph is greater than the maximum threshold, this graph will be executed in the original static process.
+
+         jit_syntax_level (int): Set the JIT syntax level for graph compiling, triggered by GRAPH_MODE and the @jit
+             decorator. The value must be ``STRICT`` or ``LAX`` . Default: ``LAX`` . All levels support all backends.
+
+             - ``STRICT`` : Only basic syntax is supported, and execution performance is optimal. Can be used for MindIR
+               load and export.
+             - ``LAX`` : Compatible with all Python syntax as much as possible. However, execution performance may be
+               affected and not optimal. Cannot be used for MindIR load and export due to some syntax that may not be
+               able to be exported.
+
+         debug_level (int): Set config for debugging. Default value: ``RELEASE``.
+
+             - ``RELEASE``: Used for normal running; some debug information will be discarded to get better
+               compiling performance.
+             - ``DEBUG``: Used for debugging when errors occur; more information will be recorded in the compiling
+               process.
+
1623
+ gpu_config (dict): Set the parameters specific to gpu hardware platform. It is not set by default.
1624
+ Currently, only setting `conv_fprop_algo` and `conv_dgrad_algo` and `conv_wgrad_algo` and `conv_allow_tf32`
1625
+ and `matmul_allow_tf32` are supported on GPU hardware platform.
1626
+
1627
+ - conv_fprop_algo (str): Specifies convolution forward algorithm and the default value is 'normal',
1628
+ The value range is as follows:
1629
+
1630
+ - normal: Use the heuristic search algorithm.
1631
+ - performance: Use the trial search algorithm.
1632
+ - implicit_gemm: This algorithm expresses the convolution as a matrix product without actually explicitly
1633
+ forming the matrix that holds the input tensor data.
1634
+ - implicit_precomp_gemm: This algorithm expresses convolution as a matrix product without actually
1635
+ explicitly forming the matrix that holds the input tensor data, but still needs some memory workspace to
1636
+ precompute some indices in order to facilitate the implicit construction of the matrix that holds the
1637
+ input tensor data.
1638
+ - gemm: This algorithm expresses the convolution as an explicit matrix product. A significant memory
1639
+ workspace is needed to store the matrix that holds the input tensor data.
1640
+ - direct: This algorithm expresses the convolution as a direct convolution (for example, without
1641
+ implicitly or explicitly doing a matrix multiplication).
1642
+ - fft: This algorithm uses the Fast-Fourier Transform approach to compute the convolution. A significant
1643
+ memory workspace is needed to store intermediate results.
1644
+ - fft_tiling: This algorithm uses the Fast-Fourier Transform approach but splits the inputs into tiles.
1645
+ A significant memory workspace is needed to store intermediate results but less than fft algorithm for
1646
+ large size images.
1647
+ - winograd: This algorithm uses the Winograd Transform approach to compute the convolution. A reasonably
1648
+ sized workspace is needed to store intermediate results.
1649
+ - winograd_nonfused: This algorithm uses the Winograd Transform approach to compute the convolution. A
1650
+ significant workspace may be needed to store intermediate results.
1651
+ - conv_dgrad_algo (str): Specifies convolution data grad algorithm and the default value is 'normal',
1652
+ The value range is as follows:
1653
+
1654
+ - normal: Use the heuristic search algorithm.
1655
+ - performance: Use the trial search algorithm.
1656
+ - algo_0: This algorithm expresses the convolution as a sum of matrix products without actually explicitly
1657
+ forming the matrix that holds the input tensor data. The sum is done using the atomic add operation,
1658
+ thus the results are non-deterministic.
1659
+ - algo_1: This algorithm expresses the convolution as a matrix product without actually explicitly forming
1660
+ the matrix that holds the input tensor data. The results are deterministic.
1661
+ - fft: This algorithm uses a Fast-Fourier Transform approach to compute the convolution. A significant
1662
+ memory workspace is needed to store intermediate results. The results are deterministic.
1663
+ - fft_tiling: This algorithm uses the Fast-Fourier Transform approach but splits the inputs into tiles.
1664
+ A significant memory workspace is needed to store intermediate results, but less than the fft algorithm for large size
1665
+ images. The results are deterministic.
1666
+ - winograd: This algorithm uses the Winograd Transform approach to compute the convolution. A reasonably
1667
+ sized workspace is needed to store intermediate results. The results are deterministic.
1668
+ - winograd_nonfused: This algorithm uses the Winograd Transform approach to compute the convolution.
1669
+ A significant workspace may be needed to store intermediate results. The results are deterministic.
1670
+ - conv_wgrad_algo (str): Specifies the convolution filter gradient algorithm. The default value is 'normal'.
1671
+ The value range is as follows:
1672
+
1673
+ - normal: Use the heuristic search algorithm.
1674
+ - performance: Use the trial search algorithm.
1675
+ - algo_0: This algorithm expresses the convolution as a sum of matrix products without actually explicitly
1676
+ forming the matrix that holds the input tensor data. The sum is done using the atomic add operation,
1677
+ thus the results are non-deterministic.
1678
+ - algo_1: This algorithm expresses the convolution as a matrix product without actually explicitly forming
1679
+ the matrix that holds the input tensor data. The results are deterministic.
1680
+ - fft: This algorithm uses a Fast-Fourier Transform approach to compute the convolution. A significant
1681
+ memory workspace is needed to store intermediate results. The results are deterministic.
1682
+ - algo_3: This algorithm is similar to algo_0 but uses some small workspace to precompute some indices.
1683
+ The results are also non-deterministic.
1684
+ - winograd_nonfused: This algorithm uses the Winograd Transform approach to compute the convolution.
1685
+ A significant workspace may be needed to store intermediate results. The results are deterministic.
1686
+ - fft_tiling: This algorithm uses the Fast-Fourier Transform approach but splits the inputs into tiles.
1687
+ A significant memory workspace is needed to store intermediate results, but less than the fft algorithm for large size
1688
+ images. The results are deterministic.
1689
+ - conv_allow_tf32 (bool): Controls whether Tensor Core TF32 computation is allowed in cuDNN. The
1690
+ default value is ``True``.
1691
+ - matmul_allow_tf32 (bool): Controls whether Tensor Core TF32 computation is allowed in cuBLAS. The
1692
+ default value is ``False``.
1693
+
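+ As a hedged illustration combining these options (one valid combination among several): the
+ deterministic gradient algorithms described above can be selected together with TF32 disabled
+ for reproducible runs:
+
+ >>> import mindspore as ms
+ >>> ms.set_context(device_target="GPU")
+ >>> ms.set_context(gpu_config={"conv_dgrad_algo": "algo_1", "conv_wgrad_algo": "algo_1",
+ ... "conv_allow_tf32": False, "matmul_allow_tf32": False})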
1694
+ jit_config (dict): Set the global jit config for compilation, taking effect in networks defined in Cell
1695
+ or with jit decorators. It is not set by default.
1696
+ The setting in context is the global jit config, while JitConfig is the local network's jit config.
1697
+ When both exist simultaneously, the global jit config does not overwrite the local network's jit config.
1698
+
1699
+ - jit_level (str): Used to control the compilation optimization level. Default: ``""`` . The framework
1700
+ automatically selects the execution method based on the product: O2 for Atlas training products and O0
1701
+ for all other products. The value range is as follows:
1702
+
1703
+ - ``"O0"``: Except for optimizations that may affect functionality, all other optimizations are turned
1704
+ off, adopt KernelByKernel execution mode.
1705
+ - ``"O1"``: Using commonly used optimizations and automatic operator fusion optimizations,
1706
+ adopt KernelByKernel execution mode.
1707
+ - ``"O2"``: Ultimate performance optimization, adopt Sink execution mode.
1708
+
1709
+ - infer_boost (str): Used to control the infer mode. Default: ``"off"`` . The value range is as follows:
1710
+
1711
+ - ``"on"``: Enable infer mode, get better infer performance.
1712
+ - ``"off"``: Disable infer mode, use forward to infer, performance is not good.
1713
+
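+ A short sketch of the precedence rule (``MyNet`` is a placeholder ``nn.Cell`` subclass): a local
+ JitConfig attached to a network is kept even when a global jit config is set:
+
+ >>> import mindspore as ms
+ >>> ms.set_context(jit_config={"jit_level": "O0"})  # global jit config
+ >>> net = MyNet()
+ >>> net.set_jit_config(ms.JitConfig(jit_level="O1"))  # the local config takes effect for this network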
1714
+ Raises:
1715
+ ValueError: If input key is not an attribute in context.
1716
+
1717
+ Examples:
1718
+ >>> import mindspore as ms
1719
+ >>> ms.set_context(mode=ms.PYNATIVE_MODE)
1720
+ >>> ms.set_context(precompile_only=True)
1721
+ >>> ms.set_context(device_target="Ascend")
1722
+ >>> ms.set_context(device_id=0)
1723
+ >>> ms.set_context(save_graphs=True, save_graphs_path="./model.ms")
1724
+ >>> ms.set_context(enable_reduce_precision=True)
1725
+ >>> ms.set_context(enable_graph_kernel=True)
1726
+ >>> ms.set_context(graph_kernel_flags="--opt_level=2 --dump_as_text")
1727
+ >>> ms.set_context(reserve_class_name_in_scope=True)
1728
+ >>> ms.set_context(variable_memory_max_size="6GB")
1729
+ >>> ms.set_context(aoe_tune_mode="online")
1730
+ >>> ms.set_context(aoe_config={"job_type": "2"})
1731
+ >>> ms.set_context(check_bprop=True)
1732
+ >>> ms.set_context(max_device_memory="3.5GB")
1733
+ >>> ms.set_context(mempool_block_size="1GB")
1734
+ >>> ms.set_context(print_file_path="print.pb")
1735
+ >>> ms.set_context(max_call_depth=80)
1736
+ >>> ms.set_context(env_config_path="./env_config.json")
1737
+ >>> ms.set_context(grad_for_scalar=True)
1738
+ >>> ms.set_context(enable_compile_cache=True, compile_cache_path="./cache.ms")
1739
+ >>> ms.set_context(pynative_synchronize=True)
1740
+ >>> ms.set_context(runtime_num_threads=10)
1741
+ >>> ms.set_context(inter_op_parallel_num=4)
1742
+ >>> ms.set_context(disable_format_transform=True)
1743
+ >>> ms.set_context(memory_optimize_level='O0')
1744
+ >>> ms.set_context(memory_offload='ON')
1745
+ >>> ms.set_context(deterministic='ON')
1746
+ >>> ms.set_context(ascend_config={"precision_mode": "force_fp16", "jit_compile": True,
1747
+ ... "atomic_clean_policy": 1, "op_precision_mode": "./op_precision_config_file",
1748
+ ... "op_debug_option": "oom",
1749
+ ... "ge_options": {"global": {"ge.opSelectImplmode": "high_precision"},
1750
+ ... "session": {"ge.exec.atomicCleanPolicy": "0"}}})
1751
+ >>> ms.set_context(jit_syntax_level=ms.STRICT)
1752
+ >>> ms.set_context(debug_level=ms.context.DEBUG)
1753
+ >>> ms.set_context(gpu_config={"conv_fprop_algo": "performance", "conv_allow_tf32": True,
1754
+ ... "matmul_allow_tf32": True})
1755
+ >>> ms.set_context(jit_config={"jit_level": "O0"})
1756
+ """
1757
+ ctx = _context()
1758
+ # set device target first
1759
+ if 'device_target' in kwargs:
1760
+ ctx.set_device_target(kwargs['device_target'])
1761
+ device = ctx.get_param(ms_ctx_param.device_target)
1762
+ _check_ascend_device_context_initialized(device, kwargs)
1763
+
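+ # Handling 'device_target' before this loop means _check_target_specific_cfgs validates the
+ # remaining keys against the device the user actually requested, not the previous target.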
1764
+ for key, value in kwargs.items():
1765
+ if key in ('enable_sparse', 'auto_tune_mode'):
1766
+ logger.warning(f"For 'context.set_context', '{key}' parameter is deprecated, "
1767
+ "and will be removed in the next version.")
1768
+ continue
1769
+ if key in ('enable_auto_mixed_precision', 'enable_dump', 'save_dump_path'):
1770
+ logger.warning(f"For 'context.set_context', '{key}' parameter is deprecated. "
1771
+ "For details, please see the interface parameter API comments")
1772
+ continue
1773
+ _check_key(key)
1774
+ if key == 'save_graphs':
1775
+ if value is True:
1776
+ value = 2
1777
+ if value is False:
1778
+ value = 0
1779
+ if value > 3:
1780
+ raise ValueError(f"value for save_graphs should be 0-3 but got '{value}'")
1781
+ if key == 'jit_syntax_level' and value not in (STRICT, COMPATIBLE, LAX):
1782
+ raise ValueError(f"For 'jit_syntax_level', the value should be context.STRICT"
1783
+ f" or context.LAX, but got {value}.")
1784
+ if key == 'debug_level' and value not in (RELEASE, DEBUG):
1785
+ raise ValueError(f"For 'debug_level', the value should be context.DEBUG"
1786
+ f" or context.RELEASE, but got {value}.")
1787
+ if key == 'enable_compile_cache':
1788
+ setattr(ctx, key, value)
1789
+ ctx.set_param(ms_ctx_param.__members__[key], int(value))
1790
+ continue
1791
+ if not _check_target_specific_cfgs(device, key):
1792
+ continue
1793
+ if hasattr(ctx, key):
1794
+ setattr(ctx, key, value)
1795
+ continue
1796
+ if key in ctx.setters:
1797
+ ctx.setters[key](ctx, value)
1798
+ continue
1799
+ # enum variables beginning with '_' are for internal use
1800
+ if key in ms_ctx_param.__members__ and key[0] != '_':
1801
+ ctx.set_param(ms_ctx_param.__members__[key], value)
1802
+ continue
1803
+ raise ValueError(f"For 'context.set_context', the keyword argument {key} is not recognized! For detailed "
1804
+ f"usage of 'set_context', please refer to the Mindspore official website.")
1805
+
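+ # A sketch of the dispatch behaviour implemented above (hypothetical helper, defined here for
+ # illustration only and never called): deprecated keys only log a warning, while unrecognized
+ # keys raise ValueError.
+ def _sketch_set_context_dispatch():
+ """Illustrative only: exercises the deprecated-key and unknown-key paths of set_context."""
+ set_context(enable_sparse=True)  # deprecated key: a warning is logged and the value is ignored
+ try:
+ set_context(no_such_option=1)  # unrecognized key
+ except ValueError as err:
+ logger.error(f"rejected: {err}")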
1806
+
1807
+ def get_context(attr_key):
1808
+ """
1809
+ Get context attribute value according to the input key.
1810
+ If some attributes are not set, they will be automatically obtained.
1811
+
1812
+ Args:
1813
+ attr_key (str): The key of the attribute.
1814
+
1815
+ Returns:
1816
+ Object, the value of the given attribute key.
1817
+
1818
+ Raises:
1819
+ ValueError: If input key is not an attribute in context.
1820
+
+ Examples:
1821
+ >>> import mindspore as ms
1822
+ >>> ms.get_context("device_target")
1823
+ >>> ms.get_context("device_id")
1824
+ """
1825
+ ctx = _context()
1826
+ device = ctx.get_param(ms_ctx_param.device_target)
1827
+ _ = _check_target_specific_cfgs(device, attr_key)
1828
+ if hasattr(ctx, attr_key):
1829
+ return getattr(ctx, attr_key)
1830
+ # enum variables beginning with '_' are for internal use
1831
+ if attr_key in ms_ctx_param.__members__ and attr_key[0] != '_':
1832
+ return ctx.get_param(ms_ctx_param.__members__[attr_key])
1833
+ raise ValueError(f"For 'context.get_context', the argument {attr_key} is not recognized! For detailed "
1834
+ f"usage of 'get_context', please refer to the Mindspore official website.")
1835
+
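+ # Round-trip sketch: attributes set through set_context can be read back here, e.g.
+ # >>> import mindspore as ms
+ # >>> ms.set_context(max_call_depth=80)
+ # >>> ms.get_context("max_call_depth")
+ # 80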
1836
+
1837
+ def _get_mode():
1838
+ """
1839
+ Get execution mode. Only for internal use.
1840
+
1841
+ Returns:
1842
+ Object, the value of the execution mode.
1843
+ """
1844
+ ctx = _context()
1845
+ return ctx.get_mode()
1846
+
1847
+
1848
+ def get_jit_config():
1849
+ """
1850
+ Get global jit config.
1851
+
1852
+ Returns:
1853
+ Object, the value of the jit config.
1854
+ """
1855
+ ctx = _context()
1856
+ return ctx.get_jit_config()
1857
+
1858
+
1859
+ class ParallelMode:
1860
+ """
1861
+ Parallel mode options.
1862
+
1863
+ There are five kinds of parallel modes, ``STAND_ALONE``, ``DATA_PARALLEL``,
1864
+ ``HYBRID_PARALLEL``, ``SEMI_AUTO_PARALLEL`` and ``AUTO_PARALLEL``. Default: ``STAND_ALONE``.
1865
+
1866
+ - ``STAND_ALONE``: Only one processor is working.
1867
+ - ``DATA_PARALLEL``: Distributes the data across different processors.
1868
+ - ``HYBRID_PARALLEL``: Achieves data parallelism and model parallelism manually.
1869
+ - ``SEMI_AUTO_PARALLEL``: Achieves data parallelism and model parallelism by setting parallel strategies.
1870
+ - ``AUTO_PARALLEL``: Achieves parallelism automatically.
1871
+
1872
+ ``MODE_LIST``: The list of all supported parallel modes.
1873
+ """
1874
+
1875
+ STAND_ALONE = "stand_alone"
1876
+ DATA_PARALLEL = "data_parallel"
1877
+ HYBRID_PARALLEL = "hybrid_parallel"
1878
+ SEMI_AUTO_PARALLEL = "semi_auto_parallel"
1879
+ AUTO_PARALLEL = "auto_parallel"
1880
+ MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
1881
+
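+ # Usage sketch (assuming the distributed environment is otherwise initialized): the class
+ # attributes are the string constants accepted by set_auto_parallel_context.
+ # >>> import mindspore as ms
+ # >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
+ # >>> assert ms.get_auto_parallel_context("parallel_mode") in ms.ParallelMode.MODE_LIST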
1882
+
1883
+ @args_type_check(enable_ps=bool)
1884
+ def set_ps_context(**kwargs):
1885
+ """
1886
+ Set parameter server training mode context.
1887
+
1888
+ Note:
1889
+ Parameter server mode is only supported in graph mode.
1890
+ Some other environment variables should also be set for parameter server training mode.
1891
+ These environment variables are listed below:
1892
+
1893
+ - MS_SERVER_NUM: Server number
1894
+ - MS_WORKER_NUM: Worker number
1895
+ - MS_SCHED_HOST: Scheduler IP address
1896
+ - MS_SCHED_PORT: Scheduler port
1897
+ - MS_ROLE: The role of this process:
1898
+
1899
+ - MS_SCHED: represents the scheduler,
1900
+ - MS_WORKER: represents the worker,
1901
+ - MS_PSERVER/MS_SERVER: represents the server
1902
+
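+ A minimal single-machine sketch of these variables for a worker process (the addresses, ports
+ and counts are placeholders):
+
+ >>> import os
+ >>> os.environ['MS_WORKER_NUM'] = '1'
+ >>> os.environ['MS_SERVER_NUM'] = '1'
+ >>> os.environ['MS_SCHED_HOST'] = '127.0.0.1'
+ >>> os.environ['MS_SCHED_PORT'] = '8118'
+ >>> os.environ['MS_ROLE'] = 'MS_WORKER'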
1903
+ Args:
1904
+ enable_ps (bool): Whether to enable parameter server training mode.
1905
+ The environment variables take effect only after enable_ps is set to ``True``.
1906
+ Default: ``False`` .
1907
+ config_file_path (string): Configuration file path used by recovery. Currently, parameter server training
1908
+ mode only supports server disaster recovery. Default: ``''`` .
1909
+ scheduler_manage_port (int): Scheduler management port, used to scale out/in. Default: ``11202`` .
1910
+ enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: ``False`` .
1911
+ client_password (str): Password to decrypt the secret key stored in the client certificate. Default: ``''`` .
1912
+ server_password (str): Password to decrypt the secret key stored in the server certificate. Default: ``''`` .
1913
+
1914
+ Raises:
1915
+ ValueError: If input key is not the attribute in parameter server training mode context.
1916
+
1917
+ Examples:
1918
+ >>> import mindspore as ms
1919
+ >>> ms.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456')
1920
+ """
1921
+ _set_ps_context(**kwargs)
1922
+
1923
+
1924
+ def get_ps_context(attr_key):
1925
+ """
1926
+ Get parameter server training mode context attribute value according to the key.
1927
+
1928
+ Args:
1929
+ attr_key (str): The key of the attribute:
1930
+
1931
+ - enable_ps (bool): Whether to enable parameter server training mode. Default: ``False`` .
1932
+ - config_file_path (string): Configuration file path used by recovery. Currently, parameter server
1933
+ training mode only supports server disaster recovery. Default: ``''`` .
1934
+ - scheduler_manage_port (int): Scheduler management port, used to scale out/in. Default: ``11202`` .
1935
+ - enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: ``False`` .
1936
+ - client_password (str): Password to decrypt the secret key stored in the client certificate.
1937
+ Default: ``''`` .
1938
+ - server_password (str): Password to decrypt the secret key stored in the server certificate.
1939
+ Default: ``''`` .
1940
+
1941
+ Returns:
1942
+ Object, the attribute value corresponding to the key.
1943
+
1944
+ Raises:
1945
+ ValueError: If input key is not an attribute in parameter server training mode context.
1946
+
1947
+ Examples:
1948
+ >>> import mindspore as ms
1949
+ >>> ms.get_ps_context("enable_ps")
1950
+ """
1951
+ return _get_ps_context(attr_key)
1952
+
1953
+
1954
+ def reset_ps_context():
1955
+ """
1956
+ Reset parameter server training mode context attributes to the default values.
1957
+
1958
+ For the meaning of each field and its default value, refer to :func:`mindspore.set_ps_context`.
1959
+
1960
+ Examples:
1961
+ >>> import mindspore as ms
1962
+ >>> ms.reset_ps_context()
1963
+ """
1964
+ _reset_ps_context()
1965
+
1966
+
1967
+ _hccl_connect_timeout = '600'
1968
+
1969
+
1970
+ def _init_parallel_env():
1971
+ """Set hccl connect timeout."""
1972
+ if 'HCCL_CONNECT_TIMEOUT' not in os.environ:
1973
+ os.environ['HCCL_CONNECT_TIMEOUT'] = _hccl_connect_timeout
1974
+
1975
+
1976
+ _init_parallel_env()
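+ # The guard in _init_parallel_env means a user-provided value wins: exporting HCCL_CONNECT_TIMEOUT
+ # (for example, ``HCCL_CONNECT_TIMEOUT=1200``) before launching the process skips the
+ # 600-second default applied here.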