mindspore-2.4.0-cp311-cp311-macosx_10_15_x86_64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (1387)
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1486 @@
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """Context of auto parallel"""
16
+ from __future__ import absolute_import
17
+ import os
18
+ import copy
19
+ import threading
20
+ from mindspore import context
21
+ import mindspore.log as logger
22
+ from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
23
+ from mindspore.parallel._ps_context import _is_role_pserver
24
+ from mindspore._c_expression import AutoParallelContext
25
+ from mindspore._checkparam import args_type_check
26
+ from mindspore import _checkparam as Validator
27
+
28
+ _MAX_GROUP_NAME_LEN = 127
29
+ _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
30
+ _DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"
31
+
32
+
33
+ class _ParallelFusionConfig:
34
+ """
35
+ The key of the Parallel fusion method configuration.
36
+ """
37
+ ALLREDUCE = "allreduce"
38
+ ALLGATHER = "allgather"
39
+ REDUCESCATTER = "reducescatter"
40
+ MODE = "mode"
41
+ FUSION_CONFIG = "config"
42
+ AUTO = "auto"
43
+ INDEX = "index"
44
+ SIZE = "size"
45
+ OPENSTATE = "openstate"
46
+ CONFIG = {"openstate": True,
47
+ "allreduce": {"mode": "auto", "config": None},
48
+ "allgather": {"mode": "auto", "config": None},
49
+ "reducescatter": {"mode": "auto", "config": None}}
50
+
51
+ @classmethod
52
+ def reset(cls):
53
+ cls.CONFIG = {"openstate": True,
54
+ "allreduce": {"mode": "auto", "config": None},
55
+ "allgather": {"mode": "auto", "config": None},
56
+ "reducescatter": {"mode": "auto", "config": None}}
57
+
58
+
59
+ class _ParallelOptimizerConfig:
60
+ """
61
+     The key of the Parallel Optimizer. There are three keys, defined below.
62
+ """
63
+ GRADIENT_ACCUMULATION_SHARD = "gradient_accumulation_shard"
64
+ PARALLEL_OPTIMIZER_THRESHOLD = "parallel_optimizer_threshold"
65
+ OPTIMIZER_WEIGHT_SHARD_SIZE = "optimizer_weight_shard_size"
66
+
67
+
68
+ class _PipelineConfig:
69
+ """
70
+ The key of the Pipeline parallelism.
71
+ """
72
+ PIPELINE_INTERLEAVE = "pipeline_interleave"
73
+ PIPELINE_SCHEDULER = "pipeline_scheduler"
74
+
75
+
76
+ class _PipelineScheduler:
77
+ PIPELINE_1F1B = "1f1b"
78
+ PIPELINE_GPIPE = "gpipe"
79
+
80
+
81
+ class _AutoParallelContext:
82
+ """
83
+     _AutoParallelContext is the environment in which operations are executed.
84
+
85
+ Note:
86
+         Creating a context by instantiating the Context object is not recommended.
87
+         Use auto_parallel_context() to get the context instead, since Context is a singleton.
88
+ """
89
+ _instance = None
90
+ _instance_lock = threading.Lock()
91
+
92
+ def __new__(cls):
93
+ if cls._instance is None:
94
+ cls._instance_lock.acquire()
95
+ cls._instance = object.__new__(cls)
96
+ cls._instance_lock.release()
97
+ return cls._instance
98
+
99
+ def __init__(self):
100
+ self._context_handle = AutoParallelContext.get_instance()
101
+ self._dataset_strategy_using_str = True
102
+
103
+ def check_context_handle(self):
104
+ """
105
+ Check context handle.
106
+
107
+ Raises:
108
+ ValueError: If the context handle is none.
109
+ """
110
+ if self._context_handle is None:
111
+ raise ValueError("Context handle is none in context!!!")
112
+
113
+ def set_device_num(self, device_num):
114
+ """
115
+ Set device num for auto parallel.
116
+
117
+ Args:
118
+ device_num (int): The device number.
119
+
120
+ Raises:
121
+ ValueError: If the device num is not a positive integer.
122
+ """
123
+ self.check_context_handle()
124
+ if device_num < 1:
125
+ raise ValueError("The context configuration parameter 'device_num' must be a positive integer, "
126
+ "but got the value of device_num : {}.".format(device_num))
127
+ from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE
128
+ self._context_handle.set_hccl_test_avaible(_HCCL_TEST_AVAILABLE)
129
+ self._context_handle.set_device_num(device_num)
130
+
131
+ def get_device_num(self):
132
+ """Get device num."""
133
+ self.check_context_handle()
134
+ return self._context_handle.get_device_num()
135
+
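As the class docstring notes, this context is normally obtained through the module-level auto_parallel_context() accessor rather than by instantiating the class directly. A minimal usage sketch, assuming that accessor is exported by this module as the docstring above indicates:

    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    ctx = auto_parallel_context()   # process-wide singleton
    ctx.set_device_num(8)           # must be a positive integer, per the check above
    print(ctx.get_device_num())     # -> 8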
136
+ def set_comm_fusion(self, config):
137
+ """
138
+ Set fusion method for auto parallel.
139
+
140
+ Args:
141
+             config (dict): A dict containing the methods and values for setting the communication fusion. Currently it
142
+                 supports: `openstate`, `allreduce`, `allgather` and `reducescatter`.
143
+
144
+ Raises:
145
+             KeyError: When a key of comm_fusion is not 'openstate', 'allreduce', 'allgather' or 'reducescatter'.
146
+ """
147
+ self.check_context_handle()
148
+ config = copy.deepcopy(config)
149
+ if _ParallelFusionConfig.OPENSTATE not in config.keys():
150
+ config[_ParallelFusionConfig.OPENSTATE] = True
151
+ for key in list(config.keys()):
152
+ if key == _ParallelFusionConfig.ALLREDUCE:
153
+ self._set_allreduce_comm_fusion(config[key])
154
+ elif key == _ParallelFusionConfig.ALLGATHER:
155
+ self._set_allgather_comm_fusion(config[key], key)
156
+ elif key == _ParallelFusionConfig.REDUCESCATTER:
157
+ self._set_allgather_comm_fusion(config[key], key)
158
+ elif key == _ParallelFusionConfig.OPENSTATE:
159
+ self._set_openstate_comm_fusion(config[key])
160
+ else:
161
+ raise KeyError("comm fusion type must be openstate,"
162
+ "allreduce, allgather or reducescatter, but got {}".format(key))
163
+ if key in _ParallelFusionConfig.CONFIG:
164
+ _ParallelFusionConfig.CONFIG[key] = config[key]
165
+
166
+ def get_comm_fusion(self):
167
+ """Get comm fusion config."""
168
+ self.check_context_handle()
169
+ return _ParallelFusionConfig.CONFIG
170
+
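The dict accepted here mirrors _ParallelFusionConfig.CONFIG defined above. A minimal sketch of driving it through the public set_auto_parallel_context interface, assuming its comm_fusion parameter forwards to set_comm_fusion above; the "size" mode and its threshold-style value are assumptions based on the SIZE key defined earlier:

    from mindspore import context

    context.set_auto_parallel_context(comm_fusion={
        "openstate": True,
        "allreduce": {"mode": "size", "config": 32},  # assumed: fusion threshold (MB) for "size" mode
    })
    print(context.get_auto_parallel_context("comm_fusion"))  # unspecified operators keep their "auto" defaults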
171
+ def set_dump_local_norm(self, dump_local_norm):
172
+ """
173
+ Set dump local norm for auto parallel.
174
+
175
+ Args:
176
+             dump_local_norm (bool): Whether to dump the local norm. Default: False.
177
+
180
+ """
181
+ self.check_context_handle()
182
+ self._context_handle.set_dump_local_norm(dump_local_norm)
183
+
184
+ def get_dump_local_norm(self):
185
+ """Get dump local norm."""
186
+ self.check_context_handle()
187
+ return self._context_handle.get_dump_local_norm()
188
+
189
+ def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
190
+ """
191
+ Set fusion threshold (MB) for auto parallel.
192
+
193
+ Args:
194
+ fusion_threshold (int): The fusion threshold (unit: MB). Default: 64.
195
+ comm_type (str): The name of the communication operator, `allreduce`, `allgather` or `reducescatter`.
196
+
197
+ Raises:
198
+             ValueError: If the fusion threshold is less than 0.
199
+ """
200
+ self.check_context_handle()
201
+ if fusion_threshold < 0:
202
+             raise ValueError("fusion threshold must not be less than 0, but got {}".format(fusion_threshold))
203
+
204
+ if comm_type == _ParallelFusionConfig.ALLREDUCE:
205
+ self._context_handle.set_fusion_threshold_mb(fusion_threshold)
206
+ if comm_type == _ParallelFusionConfig.ALLGATHER:
207
+ self._context_handle.set_allgather_fusion_threshold_mb(fusion_threshold)
208
+ if comm_type == _ParallelFusionConfig.REDUCESCATTER:
209
+ self._context_handle.set_reducescatter_fusion_threshold_mb(fusion_threshold)
210
+
211
+ def fusion_threshold_mb(self):
212
+ """Get all reduce threshold."""
213
+ self.check_context_handle()
214
+ return self._context_handle.fusion_threshold_mb()
215
+
216
+ def allgather_fusion_threshold_mb(self):
217
+ """Get allgather threshold."""
218
+ self.check_context_handle()
219
+ return self._context_handle.allgather_fusion_threshold_mb()
220
+
221
+ def reducescatter_fusion_threshold_mb(self):
222
+ """Get reducescatter threshold."""
223
+ self.check_context_handle()
224
+ return self._context_handle.reducescatter_fusion_threshold_mb()
225
+
226
+ def set_global_rank(self, global_rank):
227
+ """
228
+ Set global rank for auto parallel.
229
+
230
+ Args:
231
+ global_rank (int): The rank id of current rank.
232
+
233
+ Raises:
234
+             ValueError: If the global rank is not in [0, 4095].
235
+ """
236
+ self.check_context_handle()
237
+ if global_rank < 0 or global_rank > 4095:
238
+ raise ValueError("The context configuration parameter 'global_rank' must be in [0, 4095], "
239
+ "but got the value of global_rank : {}.".format(global_rank))
240
+ self._context_handle.set_global_rank(global_rank)
241
+
242
+ def get_global_rank(self):
243
+ """Get current rank id."""
244
+ self.check_context_handle()
245
+ return self._context_handle.get_global_rank()
246
+
247
+ def set_pipeline_stages(self, stages):
248
+ """Set the stages of the pipeline"""
249
+ if isinstance(stages, bool) or not isinstance(stages, int):
250
+ raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
251
+ "must be int, but got the type : {}.".format(type(stages)))
252
+ if stages < 1:
253
+ raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
254
+                              "should be greater than or equal to 1, but got the value of stages : {}.".format(stages))
255
+ self.check_context_handle()
256
+ self._context_handle.set_pipeline_stage_split_num(stages)
257
+
258
+ def get_pipeline_stages(self):
259
+ """Get the stages of the pipeline"""
260
+ self.check_context_handle()
261
+ return self._context_handle.get_pipeline_stage_split_num()
262
+
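A minimal sketch of configuring the stage count through the public context API, assuming its pipeline_stages parameter maps to set_pipeline_stages above:

    from mindspore import context

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", pipeline_stages=4)
    print(context.get_auto_parallel_context("pipeline_stages"))  # -> 4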
263
+ def set_auto_pipeline(self, auto_pipeline):
264
+ """Set the pipeline stage number to automatic"""
265
+ if not isinstance(auto_pipeline, bool):
266
+ raise TypeError("For 'set_auto_parallel_context', the argument 'auto_pipeline' "
267
+ "must be bool, but got the type : {}.".format(type(auto_pipeline)))
268
+ self.check_context_handle()
269
+ self._context_handle.set_auto_pipeline(auto_pipeline)
270
+
271
+ def get_auto_pipeline(self):
272
+ """Get whether the pipeline stage number is automatic"""
273
+ self.check_context_handle()
274
+ return self._context_handle.get_auto_pipeline()
275
+
276
+ def set_pipeline_result_broadcast(self, pipeline_result_broadcast):
277
+ """
278
+ Set the value of enabling pipeline result broadcast. Default: ``False``.
279
+
280
+ Args:
281
+ pipeline_result_broadcast (bool): Enable/disable broadcast the last stage result to all other stages.
282
+ """
283
+ self.check_context_handle()
284
+ if not isinstance(pipeline_result_broadcast, bool):
285
+ raise TypeError("For 'set_auto_parallel_context().set_pipeline_result_broadcast', the argument "
286
+ "'pipeline_result_broadcast' must be bool, but got the type : {}."
287
+ .format(type(pipeline_result_broadcast)))
288
+ self._context_handle.set_pipeline_result_broadcast(pipeline_result_broadcast)
289
+
290
+ def get_pipeline_result_broadcast(self):
291
+ """Get the value of enabling pipeline result broadcast"""
292
+ self.check_context_handle()
293
+ return self._context_handle.get_pipeline_result_broadcast()
294
+
295
+ def get_pipeline_interleave(self):
296
+ """Get pipeline interleave flag"""
297
+ self.check_context_handle()
298
+ return self._context_handle.get_pipeline_interleave()
299
+
300
+ def get_pipeline_scheduler(self):
301
+ """Get pipeline scheduler"""
302
+ self.check_context_handle()
303
+ return self._context_handle.get_pipeline_scheduler()
304
+
305
+ def set_pipeline_segments(self, segments):
306
+ """Set the segments of the pipeline"""
307
+ if isinstance(segments, bool) or not isinstance(segments, int):
308
+ raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_segments' "
309
+ "must be int, but got the type : {}.".format(type(segments)))
310
+ if segments < 1:
311
+ raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_segments' "
312
+                              "should be greater than or equal to 1, but got the value of segments : {}.".format(segments))
313
+ self.check_context_handle()
314
+ self._context_handle.set_pipeline_segment_split_num(segments)
315
+
316
+ def get_pipeline_segments(self):
317
+         """Get the segments of the pipeline"""
318
+ self.check_context_handle()
319
+ return self._context_handle.get_pipeline_segment_split_num()
320
+
321
+ def set_gradients_mean(self, gradients_mean):
322
+ """
323
+ Set gradients_mean flag.
324
+
325
+ Note:
326
+ If gradients_mean is true, it will insert a div operator after parameter gradients allreduce.
327
+
328
+ Args:
329
+ gradients_mean (bool): The gradients_mean flag.
330
+ """
331
+ self.check_context_handle()
332
+ self._context_handle.set_gradients_mean(gradients_mean)
333
+
334
+ def get_gradients_mean(self):
335
+ """Get gradients_mean flag."""
336
+ self.check_context_handle()
337
+ return self._context_handle.get_gradients_mean()
338
+
339
+ def set_gradient_fp32_sync(self, gradient_fp32_sync):
340
+ """
341
+ Set gradient_fp32_sync.
342
+
343
+ Note:
344
+ If gradient_fp32_sync is true,
345
+ it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.
346
+
347
+ Args:
348
+ gradient_fp32_sync (bool): The gradient_fp32_sync flag.
349
+ """
350
+ self.check_context_handle()
351
+ self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)
352
+
353
+ def get_gradient_fp32_sync(self):
354
+ """Get gradient_fp32_sync flag."""
355
+ self.check_context_handle()
356
+ return self._context_handle.get_gradient_fp32_sync()
357
+
358
+ def set_loss_repeated_mean(self, loss_repeated_mean):
359
+ """
360
+ Set loss_repeated_mean flag.
361
+
362
+ Note:
363
+ If loss_repeated_mean is true,
364
+             distributed automatic differentiation will perform a mean operator
365
+             in the backward pass in the case of repeated calculations.
366
+
367
+ Args:
368
+ loss_repeated_mean (bool): The loss_repeated_mean flag.
369
+ """
370
+ if not isinstance(loss_repeated_mean, bool):
371
+ raise TypeError("For 'set_auto_parallel_context', the argument 'loss_repeated_mean' "
372
+ "must be bool, but got the type : {}.".format(type(loss_repeated_mean)))
373
+ self.check_context_handle()
374
+ self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
375
+
376
+ def get_loss_repeated_mean(self):
377
+ """Get loss_repeated_mean flag."""
378
+ self.check_context_handle()
379
+ return self._context_handle.get_loss_repeated_mean()
380
+
381
+ def set_parallel_mode(self, parallel_mode):
382
+ """
383
+ Set parallel mode for auto parallel.
384
+
385
+ Args:
386
+ parallel_mode (str): The parallel mode of auto parallel.
387
+
388
+ Raises:
389
+ ValueError: If parallel mode is not supported.
390
+ """
391
+ self.check_context_handle()
392
+ run_mode = context.get_context("mode")
393
+ if run_mode == context.PYNATIVE_MODE and parallel_mode not in (
394
+ context.ParallelMode.DATA_PARALLEL, context.ParallelMode.STAND_ALONE,
395
+ context.ParallelMode.AUTO_PARALLEL):
396
+             raise ValueError(f"PyNative mode only supports STAND_ALONE, DATA_PARALLEL and AUTO_PARALLEL (with"
397
+                              f" sharding_propagation under the shard function)"
398
+ f" for ParallelMode, "
399
+ f"but got {parallel_mode.upper()}.")
400
+ ret = self._context_handle.set_parallel_mode(parallel_mode)
401
+ if ret is False:
402
+ raise ValueError("The context configuration parameter 'parallel_mode' only support 'stand_alone', "
403
+ "'data_parallel', 'hybrid_parallel', 'semi_auto_parallel' and 'auto_parallel', "
404
+ "but got the value : {}.".format(parallel_mode))
405
+
406
+ def get_parallel_mode(self):
407
+ """Get parallel mode."""
408
+ self.check_context_handle()
409
+ return self._context_handle.get_parallel_mode()
410
+
411
+ def set_strategy_search_mode(self, search_mode):
412
+ """
413
+ Set search mode of strategy.
414
+
415
+ Args:
416
+ search_mode (str): The search mode of strategy.
417
+ """
418
+ self.check_context_handle()
419
+ ret = self._context_handle.set_strategy_search_mode(search_mode)
420
+ if ret is False:
421
+ raise ValueError("The context configuration parameter 'auto_parallel_search_mode' only support "
422
+ "'recursive_programming', 'dynamic_programming' and 'sharding_propagation', "
423
+ "but got the value: {}."
424
+ .format(search_mode))
425
+
426
+ def get_strategy_search_mode(self):
427
+ """Get search mode of strategy."""
428
+ self.check_context_handle()
429
+ return self._context_handle.get_strategy_search_mode()
430
+
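A minimal sketch of choosing the parallel mode and strategy search mode through the public context API, with values taken from the error messages above:

    from mindspore import context

    context.set_auto_parallel_context(parallel_mode="auto_parallel",
                                      search_mode="sharding_propagation")
    print(context.get_auto_parallel_context("search_mode"))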
431
+ def set_auto_parallel_search_mode(self, search_mode):
432
+ """
433
+ Set search mode of strategy searching. This is the old version of 'search_mode', and will be deleted in a future
434
+ MindSpore version.
435
+
436
+ Args:
437
+ search_mode (str): The search mode of strategy.
438
+ """
439
+ logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. "
440
+ "The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.")
441
+ self.check_context_handle()
442
+ ret = self._context_handle.set_strategy_search_mode(search_mode)
443
+ if ret is False:
444
+ raise ValueError("The context configuration parameter 'search_mode' only support "
445
+ "'recursive_programming', 'dynamic_programming' and 'sharding_propagation', "
446
+ "but got the value: {}."
447
+ .format(search_mode))
448
+
449
+ def get_auto_parallel_search_mode(self):
450
+ """Get search mode of strategy. This is the old version of 'search_mode', and will be deleted in a future
451
+ MindSpore version.
452
+ """
453
+ logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. "
454
+ "The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.")
455
+ self.check_context_handle()
456
+ return self._context_handle.get_strategy_search_mode()
457
+
458
+ def set_sharding_propagation(self, sharding_propagation):
459
+ """
460
+ Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
461
+ will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm
462
+ will search the desired strategies. Default: ``False``.
463
+ This attribute is replaced by context.set_auto_parallel_context(search_mode="sharding_propagation").
464
+
465
+ Args:
466
+ sharding_propagation (bool): Enable/disable strategy propagation.
467
+ """
468
+ logger.warning("This attribute is replaced by "
469
+ "context.set_auto_parallel_context(search_mode='sharding_propagation'), and this attribute will"
470
+ " be deleted in a future MindSpore version.")
471
+ self.check_context_handle()
472
+ if not isinstance(sharding_propagation, bool):
473
+ raise TypeError("For 'set_auto_parallel_context().set_sharding_propagation', "
474
+ "the argument 'sharding_propagation' must be bool, but got the type : {}."
475
+ .format(type(sharding_propagation)))
476
+ self._context_handle.set_sharding_propagation(sharding_propagation)
477
+
478
+ def get_sharding_propagation(self):
479
+ """Get the value of sharding strategy propagation."""
480
+ self.check_context_handle()
481
+ return self._context_handle.get_sharding_propagation()
482
+
483
+ def set_parameter_broadcast(self, parameter_broadcast):
484
+ """
485
+ Set parameter broadcast.
486
+
487
+ Args:
488
+ parameter_broadcast (bool): Parameter broadcast or not.
489
+ """
490
+ self.check_context_handle()
491
+ self._context_handle.set_parameter_broadcast(parameter_broadcast)
492
+
493
+ def get_parameter_broadcast(self):
494
+ """Get parameter broadcast flag."""
495
+ self.check_context_handle()
496
+ return self._context_handle.get_parameter_broadcast()
497
+
498
+ def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
499
+ """
500
+ Set strategy checkpoint load path.
501
+
502
+ Args:
503
+ strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint.
504
+ """
505
+ self.check_context_handle()
506
+ self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)
507
+
508
+ def get_strategy_ckpt_load_file(self):
509
+ """Get strategy checkpoint load path."""
510
+ self.check_context_handle()
511
+ return self._context_handle.get_strategy_ckpt_load_file()
512
+
513
+ def set_full_batch(self, full_batch):
514
+ """
515
+ Set whether load full batch on each device.
516
+
517
+ Args:
518
+ full_batch (bool): True if load full batch on each device.
519
+ """
520
+ self.check_context_handle()
521
+ self._context_handle.set_full_batch(full_batch)
522
+
523
+ def get_full_batch(self):
524
+ """Get whether load full batch on each device."""
525
+ self.check_context_handle()
526
+ if _is_role_pserver():
527
+ return False
528
+ return self._context_handle.get_full_batch()
529
+
530
+ def set_dataset_strategy(self, dataset_strategy):
531
+ """
532
+ Set dataset sharding strategy.
533
+
534
+ Args:
535
+ dataset_strategy (str or tuple(tuple)): The dataset sharding strategy.
536
+ """
537
+ self.check_context_handle()
538
+ if isinstance(dataset_strategy, str):
539
+ if dataset_strategy not in ("full_batch", "data_parallel"):
540
+ raise ValueError("For 'set_auto_parallel_context', the argument "
541
+ "'dataset_strategy' must be 'full_batch' or 'data_parallel', but got the value : {}."
542
+ .format(dataset_strategy))
543
+ self._context_handle.set_full_batch(dataset_strategy == "full_batch")
544
+ self._dataset_strategy_using_str = True
545
+ return
546
+ if not isinstance(dataset_strategy, tuple):
547
+ raise TypeError("For 'set_auto_parallel_context', the argument 'dataset_strategy' "
548
+ "must be str or tuple type, but got the type : {}.".format(type(dataset_strategy)))
549
+ for ele in dataset_strategy:
550
+ if not isinstance(ele, tuple):
551
+ raise TypeError("For 'set_auto_parallel_context', the element of argument "
552
+ "'dataset_strategy' must be tuple, but got the type : {} .".format(type(ele)))
553
+ for dim in ele:
554
+ if not isinstance(dim, int):
555
+ raise TypeError("For 'set_auto_parallel_context', the element of argument "
556
+ "'dataset_strategy' must be int type, but got the type : {} .".format(type(dim)))
557
+ if context.get_context('mode') == context.PYNATIVE_MODE:
558
+ raise ValueError("In PyNative mode, the setting value of 'dataset_strategy' must be either 'full_batch' "
559
+ f"or 'data_parallel', but got {dataset_strategy}.")
560
+ self._dataset_strategy_using_str = False
561
+ self._context_handle.set_dataset_strategy(dataset_strategy)
562
+
563
+ def get_dataset_strategy(self):
564
+ """Get dataset sharding strategy."""
565
+ self.check_context_handle()
566
+ if self._dataset_strategy_using_str:
567
+ if self._context_handle.get_full_batch():
568
+ return "full_batch"
569
+ return "data_parallel"
570
+ dataset_strategy = self._context_handle.get_dataset_strategy()
571
+ if context.get_context('mode') == context.PYNATIVE_MODE:
572
+ raise ValueError("In PyNative mode, the value of 'dataset_strategy' must be either 'full_batch' "
573
+                              f"or 'data_parallel', but the current setting is {dataset_strategy}.")
574
+ return dataset_strategy
575
+
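A minimal sketch of the two accepted forms of dataset_strategy described above; the tuple-of-tuples layout (one inner tuple per dataset output) uses placeholder shard values for illustration:

    from mindspore import context

    context.set_auto_parallel_context(dataset_strategy="data_parallel")      # string form
    # tuple form (graph mode only, per the check above); placeholder strategies
    context.set_auto_parallel_context(dataset_strategy=((8, 1, 1, 1), (8,)))
    print(context.get_auto_parallel_context("dataset_strategy"))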
576
+ def set_grad_accumulation_step(self, grad_accumulation_step):
577
+ """
578
+ Set grad accumulation step.
579
+
580
+ Args:
581
+ grad_accumulation_step (int): The grad accumulation step.
582
+ """
583
+ if grad_accumulation_step > 1:
584
+ raise ValueError("The interface is deprecated. To use gradient accumulation, "
585
+ "please use GradAccumulationCell in mindspore.nn.wrap.cell_wrapper.")
586
+ self.check_context_handle()
587
+ Validator.check_positive_int(grad_accumulation_step)
588
+ self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
589
+
590
+ def get_grad_accumulation_step(self):
591
+ """Get grad accumulation step."""
592
+ self.check_context_handle()
593
+ return self._context_handle.get_grad_accumulation_step()
594
+
595
+ def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
596
+ """
597
+ Set strategy checkpoint save path.
598
+
599
+ Args:
600
+             strategy_ckpt_save_file (str): Path to save parallel strategy checkpoint.
601
+ """
602
+ self.check_context_handle()
603
+ dir_path = os.path.dirname(strategy_ckpt_save_file)
604
+ if dir_path and not os.path.exists(dir_path):
605
+ os.makedirs(dir_path, mode=0o700, exist_ok=True)
606
+ self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
607
+
608
+ def get_strategy_ckpt_save_file(self):
609
+ """Get strategy checkpoint save path."""
610
+ self.check_context_handle()
611
+ return self._context_handle.get_strategy_ckpt_save_file()
612
+
613
+ def set_strategy_ckpt_config(self, strategy_ckpt_config):
614
+ """
615
+ Set strategy checkpoint config.
616
+
617
+ Args:
618
+ strategy_ckpt_config (dict): The strategy checkpoint config.
619
+ """
620
+ self.check_context_handle()
621
+ if not isinstance(strategy_ckpt_config, dict):
622
+ raise TypeError("For 'set_auto_parallel_context', the argument 'strategy_ckpt_config' "
623
+ "must be dict, but got the type : {}.".format(type(strategy_ckpt_config)))
624
+ for config_name in strategy_ckpt_config:
625
+ unknown_config = []
626
+ if config_name not in ["load_file", "save_file", "only_trainable_params"]:
627
+ unknown_config.append(config_name)
628
+
629
+ if unknown_config:
630
+ raise ValueError("Unknown config: {}".format(unknown_config))
631
+ if "load_file" in strategy_ckpt_config:
632
+ load_file = strategy_ckpt_config.get("load_file")
633
+ if not isinstance(load_file, str):
634
+ raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
635
+ "the argument 'load_file' must be str, but got the type : {} .".format(type(load_file)))
636
+ self._context_handle.set_strategy_ckpt_load_file(load_file)
637
+ if "save_file" in strategy_ckpt_config:
638
+ save_file = strategy_ckpt_config.get("save_file")
639
+ if not isinstance(save_file, str):
640
+ raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
641
+ "the argument 'save_file' must be str, but got the type : {} .".format(type(save_file)))
642
+ self._context_handle.set_strategy_ckpt_save_file(save_file)
643
+ if "only_trainable_params" in strategy_ckpt_config:
644
+ only_trainable_params = strategy_ckpt_config.get("only_trainable_params")
645
+ if not isinstance(only_trainable_params, bool):
646
+ raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
647
+ "the argument 'only_trainable_params' must be bool,"
648
+ " but got the type : {} .".format(type(only_trainable_params)))
649
+ self._context_handle.set_stra_file_only_trainable_params(only_trainable_params)
650
+
651
+ def get_strategy_ckpt_config(self):
652
+ """Get strategy checkpoint config."""
653
+ self.check_context_handle()
654
+ load_file = self._context_handle.get_strategy_ckpt_load_file()
655
+ save_file = self._context_handle.get_strategy_ckpt_save_file()
656
+ only_trainable_param = self._context_handle.get_stra_file_only_trainable_params()
657
+ return {"load_file": load_file, "save_file": save_file, "only_trainable_params": only_trainable_param}
658
+
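A minimal sketch of the strategy checkpoint configuration handled above, assuming the public strategy_ckpt_config parameter forwards to set_strategy_ckpt_config; the file path is a placeholder:

    from mindspore import context

    context.set_auto_parallel_context(strategy_ckpt_config={
        "save_file": "./strategy.ckpt",      # placeholder; the parent directory is created if missing
        "only_trainable_params": True,
    })
    print(context.get_auto_parallel_context("strategy_ckpt_save_file"))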
659
+ def set_group_ckpt_save_file(self, group_ckpt_save_file):
660
+ """Set group checkpoint save path."""
661
+ self.check_context_handle()
662
+ dir_path = os.path.dirname(group_ckpt_save_file)
663
+ if dir_path and not os.path.exists(dir_path):
664
+ os.makedirs(dir_path, mode=0o700, exist_ok=True)
665
+ self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)
666
+
667
+ def get_parameter_broadcast_is_set(self):
668
+ """Get parameter broadcast is set or not."""
669
+ self.check_context_handle()
670
+ return self._context_handle.get_parameter_broadcast_is_set()
671
+
672
+ def set_all_reduce_fusion_split_indices(self, indices, group=""):
673
+ """
674
+ Set allreduce fusion strategy by parameters indices.
675
+
676
+ Args:
677
+ indices (list): Indices list.
678
+ group (str): The communication group of hccl/nccl.
679
+
680
+ Raises:
681
+ TypeError: If type of indices item is not int.
682
+ TypeError: If group is not a python str.
683
+ """
684
+ self.check_context_handle()
685
+ if not indices:
686
+ raise ValueError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_indices', "
687
+ "the argument 'indices' can not be empty")
688
+
689
+ if isinstance(indices, (list)):
690
+ for index in indices:
691
+ if not isinstance(index, int) or isinstance(index, bool):
692
+ raise TypeError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_indices', "
693
+ "the argument 'index' must be int, but got the type : {} .".format(type(index)))
694
+ else:
695
+ raise TypeError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_indices', "
696
+ "the argument 'indices' must be list, but got the type : {} .".format(type(indices)))
697
+
698
+ if len(set(indices)) != len(indices):
699
+             raise ValueError("The argument 'indices' contains duplicate elements")
700
+
701
+ if sorted(indices) != indices:
702
+ raise ValueError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_indices', "
703
+ "the elements in argument 'indices' must be sorted in ascending order")
704
+
705
+ new_group = self._check_and_default_group(group)
706
+
707
+ self._context_handle.set_all_reduce_fusion_split_indices(indices, new_group)
708
+ if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
709
+ _set_fusion_strategy_by_idx(indices)
710
+
711
+ def get_all_reduce_fusion_split_indices(self, group=""):
712
+ """
713
+ Get allreduce fusion split indices.
714
+
715
+ Args:
716
+ group (str): The communication group of hccl/nccl.
717
+
718
+ Returns:
719
+ Return split sizes list according to the group.
720
+
721
+ Raises:
722
+ TypeError: If group is not a python str.
723
+ """
724
+ self.check_context_handle()
725
+ new_group = self._check_and_default_group(group)
726
+ return self._context_handle.get_all_reduce_fusion_split_indices(new_group)
727
+
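A minimal sketch of configuring allreduce fusion by parameter indices, assuming the public all_reduce_fusion_config parameter maps to set_all_reduce_fusion_split_indices above; the index values are placeholders and must be unique ints sorted in ascending order, per the checks above:

    from mindspore import context

    context.set_auto_parallel_context(all_reduce_fusion_config=[20, 35, 50])
    print(context.get_auto_parallel_context("all_reduce_fusion_config"))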
728
+ def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
729
+ """
730
+ Set allreduce fusion strategy by parameters data sizes.
731
+
732
+ Args:
733
+ sizes (list): Sizes list.
734
+ group (str): The communication group of hccl/nccl.
735
+
736
+ Raises:
737
+ TypeError: If type of sizes item is not int.
738
+ TypeError: If group is not a python str.
739
+ """
740
+ self.check_context_handle()
741
+ if isinstance(sizes, (list)):
742
+ for size in sizes:
743
+ if not isinstance(size, int) or isinstance(size, bool):
744
+ raise TypeError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_sizes', "
745
+                                     "each element of the argument 'sizes' must be int, but got the type : {}.".format(type(size)))
746
+ else:
747
+ raise TypeError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_sizes', "
748
+ "the argument 'sizes' must be list, but got the type : {}.".format(type(sizes)))
749
+
750
+ new_group = self._check_and_default_group(group)
751
+ self._context_handle.set_all_reduce_fusion_split_sizes(sizes, new_group)
752
+ if context.get_context("device_target") == "Ascend":
753
+ _set_fusion_strategy_by_size(sizes)
754
+
755
+ def get_all_reduce_fusion_split_sizes(self, group=""):
756
+ """
757
+ Get allreduce fusion split sizes.
758
+
759
+ Args:
760
+ group (str): The communication group of hccl/nccl.
761
+
762
+ Returns:
763
+ Return split sizes list according to the group.
764
+
765
+ Raises:
766
+ TypeError: If group is not a python str.
767
+ """
768
+ self.check_context_handle()
769
+ new_group = self._check_and_default_group(group)
770
+ return self._context_handle.get_all_reduce_fusion_split_sizes(new_group)
771
+
772
+ def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
773
+ """
774
+ Set enable/disable all reduce fusion.
775
+
776
+ Args:
777
+ enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.
778
+ """
779
+ self.check_context_handle()
780
+ if not isinstance(enable_all_reduce_fusion, bool):
781
+ raise TypeError("For 'set_auto_parallel_context().set_enable_all_reduce_fusion', "
782
+ "the argument 'enable_fusion' must be bool, but got the type : {}."
783
+ .format(type(enable_all_reduce_fusion)))
784
+ self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)
785
+
786
+ def set_enable_all_gather_fusion(self, enable_all_gather_fusion):
787
+ """
788
+ Set enable/disable all gather fusion.
789
+
790
+ Args:
791
+ enable_all_gather_fusion (bool): Enable/disable all gather fusion.
792
+ """
793
+ self.check_context_handle()
794
+ if not isinstance(enable_all_gather_fusion, bool):
795
+ raise TypeError("For 'set_auto_parallel_context().set_enable_all_gather_fusion', "
796
+ "the argument 'enable_fusion' must be bool, but got the type : {}."
797
+ .format(type(enable_all_gather_fusion)))
798
+ self._context_handle.set_enable_all_gather_fusion(enable_all_gather_fusion)
799
+
800
+ def set_enable_reduce_scatter_fusion(self, enable_reduce_scatter_fusion):
801
+ """
802
+ Set enable/disable reduce scatter fusion.
803
+
804
+ Args:
805
+ enable_reduce_scatter_fusion (bool): Enable/disable reduce scatter fusion.
806
+ """
807
+ self.check_context_handle()
808
+ if not isinstance(enable_reduce_scatter_fusion, bool):
809
+ raise TypeError("For 'set_auto_parallel_context().set_enable_reduce_scatter_fusion', "
810
+ "the argument 'enable_fusion' must be bool, but got the type : {}."
811
+ .format(type(enable_reduce_scatter_fusion)))
812
+ self._context_handle.set_enable_reduce_scatter_fusion(enable_reduce_scatter_fusion)
813
+
814
+ def get_enable_all_reduce_fusion(self):
815
+ """Get all reduce fusion flag."""
816
+ self.check_context_handle()
817
+ return self._context_handle.get_enable_all_reduce_fusion()
818
+
819
+ def get_enable_all_gather_fusion(self):
820
+ """Get all gather fusion flag."""
821
+ self.check_context_handle()
822
+ return self._context_handle.get_enable_all_gather_fusion()
823
+
824
+ def get_enable_reduce_scatter_fusion(self):
825
+ """Get reduce scatter flag."""
826
+ self.check_context_handle()
827
+ return self._context_handle.get_enable_reduce_scatter_fusion()
828
+
829
+ def get_device_num_is_set(self):
830
+ """Get device number is set or not."""
831
+ self.check_context_handle()
832
+ return self._context_handle.get_device_num_is_set()
833
+
834
+ def get_global_rank_is_set(self):
835
+ """Get global rank is set or not."""
836
+ self.check_context_handle()
837
+ return self._context_handle.get_global_rank_is_set()
838
+
839
+ def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
840
+ """
841
+ Set enable/disable parallel optimizer.
842
+
843
+ Args:
844
+ enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
845
+ """
846
+ self.check_context_handle()
847
+ if not isinstance(enable_parallel_optimizer, bool):
848
+ raise TypeError("For 'set_auto_parallel_context', "
849
+ "the argument 'enable_parallel_optimizer' must be bool, but got the type : {}."
850
+ .format(type(enable_parallel_optimizer)))
851
+ self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
852
+
853
+ def set_force_fp32_communication(self, force_fp32_communication):
854
+ """
855
+ Set enable/disable force fp32 communication.
856
+
857
+ Args:
858
+ force_fp32_communication (bool): Enable/disable force fp32 communication.
859
+ """
860
+ self.check_context_handle()
861
+ if not isinstance(force_fp32_communication, bool):
862
+ raise TypeError("For 'set_auto_parallel_context', "
863
+ "the argument 'force_fp32_communication' must be bool, but got the type : {}."
864
+ .format(type(force_fp32_communication)))
865
+ self._context_handle.set_force_fp32_communication(force_fp32_communication)
866
+
867
+ def get_enable_fold_pipeline(self):
868
+ """Get parallel optimizer flag."""
869
+ self.check_context_handle()
870
+ return self._context_handle.get_enable_fold_pipeline()
871
+
872
+ def set_pipeline_config(self, pipeline_config):
873
+ r"""
874
+ Set the configuration for pipeline parallelism. The configuration provides finer-grained control over
+ parallel training behavior when pipeline parallelism is enabled.
876
+
877
+ Args:
878
+ pipeline_config (dict): The configuration for pipeline parallelism. It supports following keys:
879
+
880
+ - pipeline_interleave(bool): Setting it to true enables the interleave scheduler for pipeline parallelism. This
+ scheduler requires more memory but produces less bubble.
+ - pipeline_scheduler(string): There are two choices, "1f1b" and "gpipe". Default: "1f1b".
+
+ - 1f1b: It requires less memory and a smaller bubble ratio, because it runs the backward pass as soon as the
+ corresponding forward pass finishes.
+ - gpipe: It requires more memory and a larger bubble ratio, because it runs the backward passes only after all
+ forward passes have finished.
888
+
889
+ Raises:
890
+ TypeError: If the type of `pipeline_config` is not `dict`.
891
+ ValueError: If the key in `pipeline_config` not in ["pipeline_interleave", "pipeline_scheduler"].
892
+ ValueError: If pipeline interleave is False and pipeline scheduler is not `1f1b`.
893
+ """
894
+ self.check_context_handle()
895
+
896
+ if not isinstance(pipeline_config, dict):
897
+ raise TypeError("For 'set_pipeline_config', the argument 'pipeine_config' "
898
+ "must be dict, but got the type : {}.".format(type(pipeline_config)))
899
+
900
+ pp_interleave = _PipelineConfig.PIPELINE_INTERLEAVE
901
+ pp_scheduler = _PipelineConfig.PIPELINE_SCHEDULER
902
+
903
+ for config_name in pipeline_config:
904
+ unknown_config = []
905
+ if config_name not in [pp_interleave, pp_scheduler]:
906
+ unknown_config.append(config_name)
907
+
908
+ if unknown_config:
909
+ raise ValueError("Unknown config: {}".format(unknown_config))
910
+
911
+ Validator.check_bool(
912
+ pipeline_config[pp_interleave], pp_interleave, pp_interleave)
913
+ self._context_handle.set_pipeline_interleave(
914
+ pipeline_config[pp_interleave])
915
+
916
+ Validator.check_string(pipeline_config[pp_scheduler], [_PipelineScheduler.PIPELINE_1F1B,
917
+ _PipelineScheduler.PIPELINE_GPIPE])
918
+ if not pipeline_config[pp_interleave] and pipeline_config[pp_scheduler] != _PipelineScheduler.PIPELINE_1F1B:
919
+ raise ValueError(f"When pipeline_interleave is False, {pp_scheduler} is not supported")
920
+
921
+ self._context_handle.set_pipeline_scheduler(pipeline_config[pp_scheduler])
922
+
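+ # A minimal usage sketch of the pipeline configuration dict described above; the values are
+ # illustrative, not recommendations:
+ #
+ #     auto_parallel_context().set_pipeline_config({
+ #         "pipeline_interleave": True,    # enable the interleave scheduler
+ #         "pipeline_scheduler": "1f1b",   # "gpipe" is the other accepted value
+ #     })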
923
+ def get_enable_parallel_optimizer(self):
924
+ """Get parallel optimizer flag."""
925
+ self.check_context_handle()
926
+ return self._context_handle.get_enable_parallel_optimizer()
927
+
928
+ def get_force_fp32_communication(self):
929
+ """Get force fp32 communication flag."""
930
+ self.check_context_handle()
931
+ return self._context_handle.get_force_fp32_communication()
932
+
933
+
934
+ def set_parallel_optimizer_config(self, parallel_optimizer_config):
935
+ r"""
936
+ Set the configuration for the parallel optimizer. The configuration provides finer-grained control over parallel
+ training behavior when the parallel optimizer is enabled.
938
+
939
+ Args:
940
+ parallel_optimizer_config(dict): A dict containing the keys and values for the parallel optimizer
+ configuration. It supports the following keys:
942
+
943
+ - gradient_accumulation_shard(bool): If true, the accumulated gradient parameters will be sharded
+ across the data parallel devices. This introduces additional
+ communication cost (ReduceScatter) at each step when accumulating the
+ gradients, but saves a large amount of device memory,
+ so the model can be trained with a larger batch size.
+ This configuration is effective only when the model runs with pipeline
+ training or gradient accumulation with data parallelism.
950
+
951
+ - parallel_optimizer_threshold(int): Set the threshold of parallel optimizer. When parallel optimizer is
952
+ enabled, parameters with size smaller than this threshold will not be
953
+ sharded across the devices. Parameter size = shape[0] \* ... \*
954
+ shape[n] \* size(dtype). Non-negative. Unit: KB. Default: 64.
955
+ - optimizer_weight_shard_size(int): Set the optimizer weight shard group size if you want to specify the
+ maximum group size across devices when the parallel optimizer is
+ enabled. The numerical range can be (0, device_num]. Default: -1,
+ which means the optimizer weight shard group size will be the
+ data parallel group size of each parameter.
960
+
961
+ """
962
+ self.check_context_handle()
963
+ grad_shard_name = _ParallelOptimizerConfig.GRADIENT_ACCUMULATION_SHARD
964
+ threshold_name = _ParallelOptimizerConfig.PARALLEL_OPTIMIZER_THRESHOLD
965
+ optimizer_weight_shard_size_name = _ParallelOptimizerConfig.OPTIMIZER_WEIGHT_SHARD_SIZE
966
+
967
+ for config_name in parallel_optimizer_config:
968
+ unknown_config = []
969
+ if config_name not in [grad_shard_name, threshold_name, optimizer_weight_shard_size_name]:
970
+ unknown_config.append(config_name)
971
+
972
+ if unknown_config:
973
+ raise ValueError("Unknown config: {}".format(unknown_config))
974
+
975
+ if grad_shard_name in parallel_optimizer_config:
976
+ Validator.check_bool(
977
+ parallel_optimizer_config[grad_shard_name], grad_shard_name, grad_shard_name)
978
+ self._context_handle.set_grad_accumulation_shard(
979
+ parallel_optimizer_config[grad_shard_name])
980
+
981
+ if threshold_name in parallel_optimizer_config:
982
+ Validator.check_non_negative_int(
983
+ parallel_optimizer_config[threshold_name])
984
+ self._context_handle.set_parallel_optimizer_threshold(
985
+ parallel_optimizer_config[threshold_name])
986
+
987
+ if optimizer_weight_shard_size_name in parallel_optimizer_config:
988
+ value = parallel_optimizer_config[optimizer_weight_shard_size_name]
989
+ Validator.check_positive_int(value)
990
+ self.set_optimizer_weight_shard_size(value)
991
+
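+ # A minimal usage sketch with illustrative values. The threshold unit is KB: a float32 parameter
+ # of shape (1024, 1024) occupies 1024 * 1024 * 4 bytes = 4096 KB, which exceeds the default 64 KB
+ # threshold, so that parameter would be sharded.
+ #
+ #     auto_parallel_context().set_parallel_optimizer_config({
+ #         "gradient_accumulation_shard": True,
+ #         "parallel_optimizer_threshold": 64,
+ #         "optimizer_weight_shard_size": 2,
+ #     })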
992
+ def get_grad_accumulation_shard(self):
993
+ """Get grad accumulation shard."""
994
+ self.check_context_handle()
995
+ return self._context_handle.get_grad_accumulation_shard()
996
+
997
+ def get_parallel_optimizer_threshold(self):
998
+ """Get parallel optimizer threshold."""
999
+ self.check_context_handle()
1000
+ return self._context_handle.get_parallel_optimizer_threshold()
1001
+
1002
+ def set_enable_alltoall(self, enable_a2a):
1003
+ """
1004
+ Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll.
1005
+ Default: ``False``.
1006
+
1007
+ Args:
1008
+ enable_a2a (bool): Enable/disable AllToAll.
1009
+ """
1010
+ self.check_context_handle()
1011
+ if not isinstance(enable_a2a, bool):
1012
+ raise TypeError("For 'set_auto_parallel_context().set_enable_alltoall', the argument 'enable_a2a' "
1013
+ "must be bool, but got the type : {}.".format(type(enable_a2a)))
1014
+ self._context_handle.set_enable_alltoall(enable_a2a)
1015
+
1016
+ def get_enable_alltoall(self):
1017
+ """Get the value of enabling AllToAll."""
1018
+ self.check_context_handle()
1019
+ return self._context_handle.get_enable_alltoall()
1020
+
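+ # Usage sketch (illustrative): enable AllToAll so it is not replaced by AllGather and Split.
+ #
+ #     auto_parallel_context().set_enable_alltoall(True)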
1021
+ def set_communi_parallel_mode(self, communi_parallel_mode):
1022
+ """
1023
+ Set communication parallel mode.
1024
+
1025
+ Args:
1026
+ communi_parallel_mode (str): The communication parallel mode.
1027
+
1028
+ Raises:
1029
+ ValueError: If parallel mode is not supported.
1030
+ """
1031
+ if not isinstance(communi_parallel_mode, str):
1032
+ raise TypeError("For 'set_auto_parallel_context().set_communi_parallel_mode', "
1033
+ "the argument 'communi_parallel_mode' must be str, but got the type : {}."
1034
+ .format(type(communi_parallel_mode)))
1035
+ self.check_context_handle()
1036
+ ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode)
1037
+ if ret is False:
1038
+ raise ValueError("For 'set_auto_parallel_context().set_communi_parallel_mode', "
1039
+ "the argument 'communi_parallel_mode' only support 'ALL_GROUP_PARALLEL', "
1040
+ "'SAME_SEVER_GROUP_PARALLEL' and 'NO_GROUP_PARALLEL', "
1041
+ "but got the value : {}.".format(communi_parallel_mode))
1042
+
1043
+ def get_communi_parallel_mode(self):
1044
+ """Get communication parallel mode."""
1045
+ self.check_context_handle()
1046
+ return self._context_handle.get_communi_parallel_mode()
1047
+
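+ # Usage sketch; the lowercase mode names documented in _set_auto_parallel_context below are
+ # assumed to be the accepted values:
+ #
+ #     auto_parallel_context().set_communi_parallel_mode("same_server_group_parallel")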
1048
+ def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size):
1049
+ """
1050
+ Set optimizer_weight_shard_size.
1051
+
1052
+ Args:
1053
+ optimizer_weight_shard_size (int): Optimizer weight shard group size when the parallel optimizer
+ is not used globally across devices.
1055
+ """
1056
+ self.check_context_handle()
1057
+ if not isinstance(optimizer_weight_shard_size, int) or isinstance(optimizer_weight_shard_size, bool):
1058
+ raise TypeError(f"The type of optimizer_weight_shard_size must be int, \
1059
+ but got {type(optimizer_weight_shard_size)}.")
1060
+ if optimizer_weight_shard_size <= 1:
1061
+ logger.warning("The setting 'optimizer_weight_shard_size' is invalid. "
1062
+ "Please use the integer larger than 1.")
1063
+ return
1064
+ self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size)
1065
+
1066
+ def get_optimizer_weight_shard_size(self):
1067
+ """Get optimizer_weight_shard_size."""
1068
+ self.check_context_handle()
1069
+ return self._context_handle.get_optimizer_weight_shard_size()
1070
+
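+ # Usage sketch (illustrative): shard optimizer weights within groups of 4 devices instead of the
+ # full data parallel group; values not larger than 1 are ignored with a warning by the setter above.
+ #
+ #     auto_parallel_context().set_optimizer_weight_shard_size(4)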
1071
+ def set_ops_strategy_json_config(self, type, path, mode):
1072
+ """
1073
+ Set the configuration for saving or loading the operator strategy .json file.
1074
+ """
1075
+ self.check_context_handle()
1076
+ self._context_handle.set_ops_strategy_json_config(type, path, mode)
1077
+
1078
+ def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save):
1079
+ """
1080
+ Set optimizer_weight_shard_aggregated_save.
1081
+
1082
+ Args:
1083
+ optimizer_weight_shard_aggregated_save (bool): Whether to save the aggregated weight shards when
+ the parallel optimizer is enabled.
1085
+ """
1086
+ self.check_context_handle()
1087
+ if not isinstance(optimizer_weight_shard_aggregated_save, bool):
1088
+ raise TypeError("The argument 'optimizer_weight_shard_aggregated_save' must be bool, "
+ "but got the type : {}.".format(type(optimizer_weight_shard_aggregated_save)))
1089
+ self._context_handle.set_optimizer_weight_shard_aggregated_save(optimizer_weight_shard_aggregated_save)
1090
+
1091
+ def get_optimizer_weight_shard_aggregated_save(self):
1092
+ """Get optimizer_weight_shard_size."""
1093
+ self.check_context_handle()
1094
+ return self._context_handle.get_optimizer_weight_shard_aggregated_save()
1095
+
1096
+ def get_full_batch_is_set(self):
1097
+ """Get full batch attr"""
1098
+ self.check_context_handle()
1099
+ return self._context_handle.get_full_batch_is_set()
1100
+
1101
+ def reset(self):
1102
+ """Reset all settings."""
1103
+ self.check_context_handle()
1104
+ self._context_handle.reset()
1105
+ _ParallelFusionConfig.reset()
1106
+
1107
+ def _check_and_default_group(self, group):
1108
+ """Validate the given group, if group is empty, returns a default fusion group"""
1109
+ if isinstance(group, (str)):
1110
+ group_len = len(group)
1111
+ if group_len > _MAX_GROUP_NAME_LEN:
1112
+ raise ValueError(f'Group name length is out of range {_MAX_GROUP_NAME_LEN}')
1113
+ else:
1114
+ raise TypeError('Group must be a python str')
1115
+
1116
+ if group == "":
1117
+ if context.get_context("device_target") == "Ascend":
1118
+ group = _DEFAULT_HCCL_FUSION_GROUP_NAME
1119
+ else:
1120
+ group = _DEFAULT_NCCL_FUSION_GROUP_NAME
1121
+ return group
1122
+
1123
+ def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"):
1124
+ """
1125
+ Set allgather and reducescatter fusion method for auto parallel.
1126
+
1127
+ Args:
1128
+ comm_fusion (dict): A dict containing the methods and values for setting the fusion method. Currently it
+ supports two fusion methods: `auto` and `size`.
1130
+ comm_type (str): The name of the communication operator, `allgather` or `reducescatter`.
1131
+
1132
+ Raises:
1133
+ KeyError: When key of comm_fusion is not 'mode' or 'config'.
1134
+ KeyError: When `mode` is not 'auto' or 'size'.
1135
+ """
1136
+ self.check_context_handle()
1137
+ if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
1138
+ self.set_enable_all_gather_fusion(True)
1139
+ if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
1140
+ self.set_enable_reduce_scatter_fusion(True)
1141
+ if not isinstance(comm_fusion, dict):
1142
+ raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
1143
+ comm_type, type(comm_fusion)))
1144
+ if _ParallelFusionConfig.MODE not in comm_fusion:
1145
+ raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
1146
+ if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
1147
+ raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
1148
+ check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE]
1149
+ if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
1150
+ self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
1151
+ else:
1152
+ raise KeyError("fusion method mode must be auto or size, but got {}".format(
1153
+ comm_fusion[_ParallelFusionConfig.MODE]))
1154
+
1155
+ fusion_threshold = 64
1156
+ if comm_fusion[_ParallelFusionConfig.MODE] != _ParallelFusionConfig.AUTO:
1157
+ fusion_threshold = comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]
1158
+ self.set_fusion_threshold_mb(fusion_threshold, comm_type)
1159
+
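+ # A hedged sketch of the sub-dict this helper consumes; the "mode"/"config" keys follow the
+ # comm_fusion description in _set_auto_parallel_context below, and the 32 MB threshold is an
+ # illustrative value:
+ #
+ #     auto_parallel_context()._set_allgather_comm_fusion({"mode": "size", "config": 32},
+ #                                                        comm_type="allgather")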
1160
+ def _set_allreduce_comm_fusion(self, comm_fusion):
1161
+ """
1162
+ Set fusion method for auto parallel.
1163
+
1164
+ Args:
1165
+ comm_fusion (dict): A dict containing the methods and values for setting the fusion method. Currently it
+ supports three fusion methods: `auto`, `size` and `index`.
1167
+
1168
+ Raises:
1169
+ KeyError: When key of comm_fusion is not 'mode' or 'config'.
1170
+ KeyError: When `mode` is not 'auto', 'size' or 'index'.
1171
+ """
1172
+ self.check_context_handle()
1173
+ if not self.get_enable_all_reduce_fusion():
1174
+ self.set_enable_all_reduce_fusion(True)
1175
+ if not isinstance(comm_fusion, dict):
1176
+ raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
1177
+ type(comm_fusion)))
1178
+ if _ParallelFusionConfig.MODE not in comm_fusion:
1179
+ raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
1180
+ if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
1181
+ raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
1182
+ check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE]
1183
+ if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
1184
+ self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
1185
+ else:
1186
+ raise KeyError("fusion method mode must be auto, index or size, but got {}".format(
1187
+ comm_fusion[_ParallelFusionConfig.MODE]))
1188
+ if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO:
1189
+ self.set_fusion_threshold_mb(fusion_threshold=64)
1190
+ if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE:
1191
+ self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
1192
+ if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
1193
+ self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
1194
+
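+ # A hedged sketch of index-mode allreduce fusion, which forwards to
+ # set_all_reduce_fusion_split_indices; the indices are placeholders:
+ #
+ #     auto_parallel_context()._set_allreduce_comm_fusion({"mode": "index", "config": [20, 35]})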
1195
+ def _set_openstate_comm_fusion(self, openstate):
1196
+ """
1197
+ Set open state for comm fusion.
1198
+
1199
+ Args:
1200
+ openstate (bool): Whether to turn the communication fusion on or off. Currently it
+ supports two states: `True` or `False`.
1202
+
1203
+ Raises:
1204
+ TypeError: When the value is not bool.
1205
+ """
1206
+ self.check_context_handle()
1207
+ if not isinstance(openstate, bool):
1208
+ raise TypeError("For 'comm_fusion', the 'openstate' must be bool, but got the type : {}.".format(
1209
+ type(openstate)))
1210
+ if not openstate:
1211
+ self.set_enable_all_reduce_fusion(openstate)
1212
+ self.set_enable_all_gather_fusion(openstate)
1213
+ self.set_enable_reduce_scatter_fusion(openstate)
1214
+
1215
+
1216
+ def _set_ops_strategy_json_config(type="SAVE", path="", mode="all"):
1217
+ """
1218
+ Set strategy json configuration.
1219
+
1220
+ Args:
1221
+ type (str): The parameter for choosing whether to save or load the .json file.
+ path (str): Path to save or load parallel strategy json.
+ mode (str): The parameter for choosing whether to save all operators or only the principal ones.
1224
+
1225
+ Raises:
1226
+ KeyError: When type is not 'SAVE' or 'LOAD'.
1227
+ KeyError: When mode is not 'all' or 'principal'.
1228
+ """
1229
+ dir_path = os.path.dirname(path)
1230
+ if dir_path and not os.path.exists(dir_path):
1231
+ os.makedirs(dir_path, mode=0o700, exist_ok=True)
1232
+ check_type = ["SAVE", "LOAD"]
1233
+ check_mode = ["all", "principal"]
1234
+ if type in check_type and mode in check_mode:
1235
+ auto_parallel_context().set_ops_strategy_json_config(type, path, mode)
1236
+ else:
1237
+ raise KeyError("Type must be 'SAVE' or 'LOAD' and mode must be 'all' or 'principal'")
1238
+
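+ # Usage sketch (illustrative path): dump the full operator strategy to a JSON file.
+ #
+ #     _set_ops_strategy_json_config(type="SAVE", path="./strategy/ops_strategy.json", mode="all")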
1239
+
1240
+ _AUTO_PARALLEL_CONTEXT = None
1241
+
1242
+
1243
+ def auto_parallel_context():
1244
+ """
1245
+ Get the global _AUTO_PARALLEL_CONTEXT, if it is not created, create a new one.
1246
+
1247
+ Returns:
1248
+ _AutoParallelContext, the global auto parallel context.
1249
+ """
1250
+ global _AUTO_PARALLEL_CONTEXT
1251
+ if _AUTO_PARALLEL_CONTEXT is None:
1252
+ _AUTO_PARALLEL_CONTEXT = _AutoParallelContext()
1253
+ return _AUTO_PARALLEL_CONTEXT
1254
+
1255
+
1256
+ _set_auto_parallel_context_func_map = {
1257
+ "device_num": auto_parallel_context().set_device_num,
1258
+ "global_rank": auto_parallel_context().set_global_rank,
1259
+ "gradients_mean": auto_parallel_context().set_gradients_mean,
1260
+ "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
1261
+ "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
1262
+ "pipeline_stages": auto_parallel_context().set_pipeline_stages,
1263
+ "auto_pipeline": auto_parallel_context().set_auto_pipeline,
1264
+ "pipeline_result_broadcast": auto_parallel_context().set_pipeline_result_broadcast,
1265
+ "pipeline_segments": auto_parallel_context().set_pipeline_segments,
1266
+ "parallel_mode": auto_parallel_context().set_parallel_mode,
1267
+ "search_mode": auto_parallel_context().set_strategy_search_mode,
1268
+ "auto_parallel_search_mode": auto_parallel_context().set_auto_parallel_search_mode,
1269
+ "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
1270
+ "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
1271
+ "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
1272
+ "group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file,
1273
+ "full_batch": auto_parallel_context().set_full_batch,
1274
+ "dataset_strategy": auto_parallel_context().set_dataset_strategy,
1275
+ "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
1276
+ "force_fp32_communication": auto_parallel_context().set_force_fp32_communication,
1277
+ "parallel_optimizer_config": auto_parallel_context().set_parallel_optimizer_config,
1278
+ "pipeline_config": auto_parallel_context().set_pipeline_config,
1279
+ "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
1280
+ "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
1281
+ "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
1282
+ "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
1283
+ "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
1284
+ "sharding_propagation": auto_parallel_context().set_sharding_propagation,
1285
+ "enable_alltoall": auto_parallel_context().set_enable_alltoall,
1286
+ "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
1287
+ "comm_fusion": auto_parallel_context().set_comm_fusion,
1288
+ "dump_local_norm": auto_parallel_context().set_dump_local_norm}
1289
+
1290
+ _get_auto_parallel_context_func_map = {
1291
+ "device_num": auto_parallel_context().get_device_num,
1292
+ "global_rank": auto_parallel_context().get_global_rank,
1293
+ "gradients_mean": auto_parallel_context().get_gradients_mean,
1294
+ "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
1295
+ "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
1296
+ "pipeline_stages": auto_parallel_context().get_pipeline_stages,
1297
+ "auto_pipeline": auto_parallel_context().get_auto_pipeline,
1298
+ "pipeline_result_broadcast": auto_parallel_context().get_pipeline_result_broadcast,
1299
+ "pipeline_interleave": auto_parallel_context().get_pipeline_interleave,
1300
+ "pipeline_scheduler": auto_parallel_context().get_pipeline_scheduler,
1301
+ "parallel_mode": auto_parallel_context().get_parallel_mode,
1302
+ "search_mode": auto_parallel_context().get_strategy_search_mode,
1303
+ "auto_parallel_search_mode": auto_parallel_context().get_auto_parallel_search_mode,
1304
+ "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
1305
+ "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
1306
+ "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
1307
+ "full_batch": auto_parallel_context().get_full_batch,
1308
+ "dataset_strategy": auto_parallel_context().get_dataset_strategy,
1309
+ "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
1310
+ "force_fp32_communication": auto_parallel_context().get_force_fp32_communication,
1311
+ "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
1312
+ "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
1313
+ "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
1314
+ "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
1315
+ "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
1316
+ "sharding_propagation": auto_parallel_context().get_sharding_propagation,
1317
+ "enable_alltoall": auto_parallel_context().get_enable_alltoall,
1318
+ "comm_fusion": auto_parallel_context().get_comm_fusion,
1319
+ "strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config,
1320
+ "full_batch_is_set": auto_parallel_context().get_full_batch_is_set,
1321
+ "dump_local_norm": auto_parallel_context().get_dump_local_norm}
1322
+
1323
+
1324
+ @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
1325
+ loss_repeated_mean=bool, parallel_mode=str, search_mode=str, auto_parallel_search_mode=str,
1326
+ parameter_broadcast=bool, strategy_ckpt_load_file=str,
1327
+ strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
1328
+ grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
1329
+ communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
1330
+ optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict,
1331
+ strategy_ckpt_config=dict, force_fp32_communication=bool)
1332
+ def _set_auto_parallel_context(**kwargs):
1333
+ """
1334
+ Set auto parallel context.
1335
+
1336
+ Note:
1337
+ Attribute name is required for setting attributes.
1338
+
1339
+ Args:
1340
+ device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
1341
+ global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
1342
+ gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: ``False``.
1343
+ loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
1344
+ calculations. Default: ``True``.
1345
+ gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
1346
+ Default: ``True``.
1347
+ parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
1348
+ "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
1349
+
1350
+ - stand_alone: Only one processor working.
1351
+
1352
+ - data_parallel: Distributing the data across different processors.
1353
+
1354
+ - hybrid_parallel: Achieving data parallelism and model parallelism manually.
1355
+
1356
+ - semi_auto_parallel: Achieving data parallelism and model parallelism by
1357
+ setting parallel strategies.
1358
+
1359
+ - auto_parallel: Achieving parallelism automatically.
1360
+ search_mode (str): There are three kinds of search modes: "recursive_programming", "dynamic_programming"
1361
+ and "sharding_propagation". Default: "dynamic_programming".
1362
+
1363
+ - recursive_programming: Recursive programming search mode.
1364
+
1365
+ - dynamic_programming: Dynamic programming search mode.
1366
+
1367
+ - sharding_propagation: Propagate shardings from configured ops to non-configured ops.
1368
+ auto_parallel_search_mode (str): This is the old version of 'search_mode'. Keeping this attribute is
+ for forward compatibility, and this attribute will be deleted in a future MindSpore version.
1370
+ parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
1371
+ "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
1372
+ broadcast. Default: ``False``.
1373
+ strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
1374
+ strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
1375
+ group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: ''
1376
+ full_batch (bool): Whether to load the whole batch on each device. Default: ``False``.
1377
+ dataset_strategy (Union[str, tuple]): Dataset sharding strategy. Default: "data_parallel".
1378
+ enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: ``False``.
1379
+ force_fp32_communication (bool): A switch that determines whether reduce operators (AllReduce, ReduceScatter)
+ are forced to use the fp32 data type for communication. Setting it to True enables
+ the switch. Default: ``False`` .
1382
+ all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
1383
+ pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
+ the devices are distributed along the pipeline. The total devices will be divided into
+ 'pipeline_stages' stages. This currently can only be used when
+ the parallel mode semi_auto_parallel is enabled. Default: 0
1387
+ auto_pipeline (bool): Set the pipeline stage number to automatic. Its value will be selected between 1 and the
1388
+ parameter `pipeline_stages`. This option requires the `parallel_mode` to be ``auto_parallel``
1389
+ and the `search_mode` to be ``recursive_programming``. Default: ``False`` .
1390
+ pipeline_result_broadcast (bool): A switch that broadcasts the last stage result to all other stages in
+ pipeline parallel inference. Default: ``False`` .
1392
+ communi_parallel_mode (str): There are three kinds of communication parallel modes, "all_group_parallel",
1393
+ "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".
1394
+
1395
+ - all_group_parallel: All communication groups are in parallel.
1396
+
1397
+ - same_server_group_parallel: Only the communication groups within the same server are parallel.
1398
+
1399
+ - no_group_parallel: All communication groups are not parallel.
1400
+ optimizer_weight_shard_size (int): Set the optimizer shard group size when the parallel optimizer is not
+ fully used. It should be larger than one and less than or equal to the data parallel size.
+ Default: -1, which means the parallel optimizer is fully used in the data parallel dimension.
1403
+ optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when enable parallel
1404
+ optimizer. Default: ``False``.
1405
+ sharding_propagation (bool): Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True,
1406
+ the strategy-configured operators will propagate the strategies to other
1407
+ operators with minimum redistribution cost; otherwise, the algorithm will
1408
+ search the desired strategies. Default: ``False``.
1409
+ enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to
1410
+ circumvent AllToAll. Default: ``False``.
1411
+ comm_fusion (dict): A dict containing the types and configurations for setting the communication fusion. Each
+ communication fusion config has two keys: "mode" and "config".
1413
+ It supports following communication fusion types and configurations:
1414
+
1415
+ - openstate: Whether to turn on the communication fusion or not. If `openstate` is `True`, turn on
1416
+ the communication fusion, otherwise, turn off the communication fusion. Default: `True`.
1417
+
1418
+ - allreduce: if communication fusion type is `allreduce`. The `mode` contains: `auto`, `size`
1419
+ and `index`. In `auto` mode, allreduce fusion is configured by gradients size, and the default
1420
+ fusion threshold is `64` MB. In 'size' mode, allreduce fusion is configured by gradients size
1421
+ manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is the same as
1422
+ `all_reduce_fusion_config`.
1423
+
1424
+ - allgather: If communication fusion type is `allgather`. The `mode` contains: `auto`, `size`.
1425
+ In `auto` mode, AllGather fusion is configured by gradients size, and the default fusion
1426
+ threshold is `64` MB. In 'size' mode, AllGather fusion is configured by gradients size
1427
+ manually, and the fusion threshold must be larger than `0` MB.
1428
+
1429
+ - reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto`
1430
+ and `size`. Config is same as `allgather`.
1431
+
1432
+
1433
+
1434
+ Raises:
1435
+ ValueError: If input key is not attribute in auto parallel context.
1436
+ """
1437
+ for key, value in kwargs.items():
1438
+ if key not in _set_auto_parallel_context_func_map:
1439
+ raise ValueError("Set context keyword %s is not recognized!" % key)
1440
+ set_func = _set_auto_parallel_context_func_map[key]
1441
+ set_func(value)
1442
+
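+ # A hedged end-to-end sketch combining several keys documented above; every value is illustrative
+ # and assumes the distributed environment (devices, ranks) has been initialized elsewhere:
+ #
+ #     _set_auto_parallel_context(
+ #         parallel_mode="semi_auto_parallel",
+ #         device_num=8,
+ #         enable_parallel_optimizer=True,
+ #         comm_fusion={"allreduce": {"mode": "size", "config": 64},
+ #                      "allgather": {"mode": "auto", "config": None}},
+ #     )
+ #     _get_auto_parallel_context("enable_parallel_optimizer")   # -> True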
1443
+
1444
+ def _get_auto_parallel_context(attr_key):
1445
+ """
1446
+ Get auto parallel context attribute value according to the key.
1447
+
1448
+ Args:
1449
+ attr_key (str): The key of the attribute.
1450
+
1451
+ Returns:
1452
+ Return attribute value according to the key.
1453
+
1454
+ Raises:
1455
+ ValueError: If input key is not attribute in auto parallel context.
1456
+ """
1457
+ if attr_key not in _get_auto_parallel_context_func_map:
1458
+ raise ValueError("Get context keyword %s is not recognized!" % attr_key)
1459
+ get_func = _get_auto_parallel_context_func_map[attr_key]
1460
+ return get_func()
1461
+
1462
+
1463
+ def _reset_auto_parallel_context():
1464
+ """
1465
+ Reset auto parallel context attributes to the default values:
1466
+
1467
+ - device_num: 1.
1468
+ - global_rank: 0.
1469
+ - gradients_mean: False.
1470
+ - gradient_fp32_sync: True.
1471
+ - parallel_mode: "stand_alone".
1472
+ - parameter_broadcast: False.
1473
+ - strategy_ckpt_load_file: ""
1474
+ - strategy_ckpt_save_file: ""
1475
+ - enable_parallel_optimizer: False
1476
+ - force_fp32_communication: False
1477
+ - search_mode: 'recursive_programming'
+ - auto_parallel_search_mode: 'recursive_programming'
1479
+ - sharding_propagation: False
1480
+ - pipeline_stages: 0
1481
+ - auto_pipeline: False
1482
+ - pipeline_result_broadcast: False
1483
+ - gradient_accumulation_shard: True
1484
+ - fusion_threshold: 64
1485
+ """
1486
+ auto_parallel_context().reset()