mindspore-2.4.10-cp310-cp310-win_amd64.whl → mindspore-2.6.0rc1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. (The original page linked to further details here; the link was not preserved.)

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -27,13 +27,30 @@ from mindspore.mint.nn.layer.normalization import BatchNorm1d
27
27
  from mindspore.mint.nn.layer.normalization import BatchNorm2d
28
28
  from mindspore.mint.nn.layer.normalization import BatchNorm3d
29
29
  from mindspore.mint.nn.layer.normalization import LayerNorm
30
+ from mindspore.mint.nn.layer.normalization import SyncBatchNorm
30
31
  from mindspore.mint.nn.layer.activation import LogSigmoid
31
32
  from mindspore.mint.nn.layer.activation import SiLU
33
+ from mindspore.mint.nn.layer.activation import Threshold
34
+ from mindspore.mint.nn.layer.basic import Dropout2d
35
+ from mindspore.mint.nn.layer.pooling import AdaptiveMaxPool1d
32
36
  from mindspore.mint.nn.layer.pooling import AdaptiveAvgPool1d
33
37
  from mindspore.mint.nn.layer.pooling import AdaptiveAvgPool2d
38
+ from mindspore.mint.nn.layer.pooling import AdaptiveAvgPool3d
34
39
 
35
40
 
36
- __all__ = []
37
- __all__.extend(normalization.__all__)
38
- __all__.extend(activation.__all__)
39
- __all__.extend(pooling.__all__)
41
+ __all__ = [
42
+ 'GroupNorm',
43
+ 'BatchNorm1d',
44
+ 'BatchNorm2d',
45
+ 'BatchNorm3d',
46
+ 'LayerNorm',
47
+ 'LogSigmoid',
48
+ 'SiLU',
49
+ 'Dropout2d',
50
+ 'AdaptiveMaxPool1d',
51
+ 'AdaptiveAvgPool1d',
52
+ 'AdaptiveAvgPool2d',
53
+ 'AdaptiveAvgPool3d',
54
+ 'SyncBatchNorm',
55
+ 'Threshold',
56
+ ]
@@ -0,0 +1,334 @@
1
+ import mindspore
2
+ from mindspore import Tensor
3
+ from mindspore import context
4
+ import mindspore.communication
5
+ import mindspore.communication.comm_func
6
+ from mindspore.nn.cell import Cell
7
+ from mindspore.ops.auto_generate.gen_ops_prim import BatchNormReduceGrad
8
+ from mindspore.ops.auto_generate.gen_ops_prim import BatchNormElemtGrad
9
+ from mindspore.communication import GlobalComm
10
+ from mindspore.ops import ReduceOp
11
+ from mindspore._c_expression import TensorPy as Tensor_
12
+ from mindspore.communication._comm_helper import _get_size_helper, HCCL_WORLD_COMM_GROUP
13
+ from mindspore.ops._primitive_cache import _get_cache_prim
14
+ from mindspore.communication.comm_func import all_gather_into_tensor as all_gather_into_tensor_dy
15
+ from mindspore.ops import operations as P
16
+ from mindspore import ops, mint
17
+
18
+
19
# Primitive instances created once at import time and shared by the forward
# and backward helpers of this module.
shape = P.Shape()
batch_norm_reduce_grad = BatchNormReduceGrad()
batch_norm_elemt_grad = BatchNormElemtGrad()

# Name of the default HCCL world group; `_get_group` translates this name into
# the current `GlobalComm.WORLD_COMM_GROUP`.
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
24
+
25
+
26
+ def _deal_comm_outputs(output, async_op):
27
+ if isinstance(output, tuple):
28
+ if not async_op:
29
+ output[1].wait()
30
+ return output[0]
31
+ return output
32
+
33
+ if not async_op:
34
+ return output
35
+ return output
36
+
37
+
38
def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
    """Return the number of ranks in the communication group ``group``.

    Args:
        group (str): Group name. ``DEFAULT_WORLD_COMM_GROUP`` is translated to
            the active world group by ``_get_group``. Default: the value of
            ``GlobalComm.WORLD_COMM_GROUP`` at the time this module was imported
            (Python evaluates the default once, at definition time).

    Returns:
        The size reported by ``_get_size_helper`` for the resolved group.

    Raises:
        TypeError: If ``group`` is not a string.
    """
    if isinstance(group, str):
        resolved = _get_group(group)
        return _get_size_helper(group=resolved)
    raise TypeError("For 'get_group_size', the argument 'group' must be type of string, "
                    "but got 'group' type : {}.".format(type(group)))
43
+
44
+
45
+ def _contiguous(tensor):
46
+ if not tensor.is_contiguous() or tensor.storage_offset() != 0:
47
+ tensor = tensor.contiguous()
48
+ return tensor
49
+
50
+
51
def _get_group(group):
    """Resolve ``group`` to a concrete communication group name.

    ``DEFAULT_WORLD_COMM_GROUP`` is replaced by the current
    ``GlobalComm.WORLD_COMM_GROUP``. Any other value is passed through unchanged.
    """
    return GlobalComm.WORLD_COMM_GROUP if group == DEFAULT_WORLD_COMM_GROUP else group
56
+
57
+
58
def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
    """Gather ``tensor`` from every rank of ``group`` with the ``P.AllGather`` primitive.

    Args:
        tensor (Tensor): Local tensor to gather; it is made contiguous first.
        group (str): Communication group name, resolved through ``_get_group``.
        async_op (bool): Forwarded to ``_deal_comm_outputs``.

    Returns:
        The primitive's result, normalized by ``_deal_comm_outputs``.

    Raises:
        TypeError: If ``tensor`` is not a ``Tensor``.
    """
    is_tensor = isinstance(tensor, (Tensor, Tensor_))
    if not is_tensor:
        raise TypeError(
            "For all_gather_into_tensor, the input tensor must be tensor")
    resolved_group = _get_group(group)
    local = _contiguous(tensor)
    # Primitive instances are cached per argument set by `_get_cache_prim`.
    gather_prim = _get_cache_prim(P.AllGather)(group=resolved_group)
    return _deal_comm_outputs(gather_prim(local), async_op)
67
+
68
+
69
def all_reduce(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
    """Reduce ``tensor`` across all ranks of ``group`` with the ``P.AllReduce`` primitive.

    Args:
        tensor (Tensor): Local tensor to reduce; it is made contiguous first.
        op (str): One of ``'sum'``, ``'prod'``, ``'min'``, ``'max'``.
        group (str): Communication group name, resolved through ``_get_group``.
        async_op (bool): Forwarded to ``_deal_comm_outputs``.

    Returns:
        The primitive's result, normalized by ``_deal_comm_outputs``.

    Raises:
        TypeError: If ``tensor`` is not a Tensor, if ``op`` is not a string, or
            if ``op`` is not one of the supported names.
    """
    # NOTE(review): an unsupported op *value* raises TypeError (ValueError would
    # be more conventional); kept as-is because callers may catch TypeError.
    supported_ops = ('sum', 'prod', 'min', 'max')
    if not isinstance(tensor, (Tensor, Tensor_)):
        raise TypeError("For all_reduce, the input tensor must be tensor")
    if not isinstance(op, str):
        raise TypeError("For all_reduce, the input op type must be str")
    if op not in supported_ops:
        raise TypeError(
            "For all_reduce, the input op value must be one of sum, prod, min, max")
    resolved_group = _get_group(group)
    local = _contiguous(tensor)
    reduce_prim = _get_cache_prim(P.AllReduce)(op=op, group=resolved_group)
    return _deal_comm_outputs(reduce_prim(local), async_op)
82
+
83
+
84
def bprop_pynative(input_x, weight, bias, running_mean, running_var, eps, momentum,
                   process_group, world_size, output, doutput):
    """Custom backward of ``SyncBatchNormInner`` used in PyNative mode.

    Of the forward inputs, only ``input_x``, ``weight`` and ``process_group``
    are read here.

    Args:
        output (tuple): Forward outputs ``(out, mean, invstd, count_all)``.
        doutput (tuple): Gradients w.r.t. ``output``. Only the first entry
            (the gradient of ``out``) is used.

    Returns:
        tuple: ``(grad_input, grad_weight, grad_bias)`` followed by six ``None``
        entries, one per remaining forward input. ``grad_weight`` and
        ``grad_bias`` are ``None`` when ``weight`` is ``None``.
    """
    _, saved_mean, saved_invstd, saved_count_all = output
    grad_out, _, _, _ = doutput

    # Original (translated) note: "KBK mode not supported". This presumably
    # refers to this is_contiguous() check; bprop_kbk calls contiguous()
    # unconditionally instead.
    if not grad_out.is_contiguous():
        grad_out = grad_out.contiguous()

    # Request input, weight and bias gradients from the reduce kernel.
    want_input, want_weight, want_bias = True, True, True

    # Local per-channel reductions, plus grad_weight / grad_bias.
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = batch_norm_reduce_grad(
        grad_out, input_x, saved_mean, saved_invstd, weight,
        want_input, want_weight, want_bias)

    grad_input = None
    if want_input:
        # Synchronize both statistics with a single collective: concatenate,
        # all-reduce over the group, then split back into channel-sized chunks.
        channels = shape(sum_dy)[0]
        reduced, _ = mindspore.communication.comm_func.all_reduce(
            mint.cat([sum_dy, sum_dy_xmu], dim=0), group=process_group)
        sum_dy, sum_dy_xmu = mint.split(reduced, channels)
        grad_input = batch_norm_elemt_grad(
            grad_out, input_x, saved_mean, saved_invstd, weight,
            sum_dy, sum_dy_xmu, saved_count_all)

    # grad_weight / grad_bias are not synchronized here; distributed training
    # is expected to all-reduce them.
    if weight is None or not want_weight:
        grad_weight = None
    if weight is None or not want_bias:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
143
+
144
+
145
def bprop_kbk(input_x, weight, bias, running_mean, running_var, eps, momentum,
              process_group, world_size, output, doutput):
    """Custom backward of ``SyncBatchNormInner`` used outside PyNative mode.

    Mirrors ``bprop_pynative`` with two differences. The incoming gradient is
    made contiguous unconditionally. The gradient statistics are summed with
    this module's ``all_reduce`` (``P.AllReduce``, default op ``ReduceOp.SUM``).

    Args:
        output (tuple): Forward outputs ``(out, mean, invstd, count_all)``.
        doutput (tuple): Gradients w.r.t. ``output``. Only the first entry is used.

    Returns:
        tuple: ``(grad_input, grad_weight, grad_bias)`` followed by six ``None``
        entries, one per remaining forward input.
    """
    _, saved_mean, saved_invstd, saved_count_all = output
    grad_out, _, _, _ = doutput
    grad_out = grad_out.contiguous()

    grad_input = None
    grad_weight = None
    grad_bias = None

    # Request input, weight and bias gradients from the reduce kernel.
    want_input = True
    want_weight = True
    want_bias = True

    # Local per-channel reductions, plus grad_weight / grad_bias.
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = batch_norm_reduce_grad(
        grad_out, input_x, saved_mean, saved_invstd, weight,
        want_input, want_weight, want_bias)

    if want_input:
        # Sum both statistics over the group in one collective: concatenate,
        # all-reduce, then split back into channel-sized chunks.
        channels = shape(sum_dy)[0]
        packed = mint.cat([sum_dy, sum_dy_xmu], dim=0)
        packed = all_reduce(packed, group=process_group)
        sum_dy, sum_dy_xmu = mint.split(packed, channels)
        grad_input = batch_norm_elemt_grad(
            grad_out, input_x, saved_mean, saved_invstd, weight,
            sum_dy, sum_dy_xmu, saved_count_all)

    # grad_weight / grad_bias are not synchronized here; distributed training
    # is expected to all-reduce them.
    if weight is None or not want_weight:
        grad_weight = None
    if weight is None or not want_bias:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
201
+
202
+
203
def construct_pynative(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
                       world_size, self_num_features, self_world_size):
    """Forward pass of synchronized batch normalization in PyNative mode.

    Each rank computes per-channel mean / inverse std and its element count.
    These are all-gathered across ``process_group`` and merged by
    ``mint.batch_norm_gather_stats_with_counts`` into global statistics, which
    normalize ``input`` via ``mint.batch_norm_elemt``.

    Args:
        input (Tensor): Input whose dimension 1 is the channel dimension.
        weight: Affine scale; may be ``None``. Made contiguous when present.
        bias: Affine shift, forwarded to ``batch_norm_elemt``.
        running_mean, running_var: Forwarded to ``batch_norm_gather_stats_with_counts``.
        eps (float): Epsilon forwarded to the batch-norm kernels.
        momentum (float): Forwarded to ``batch_norm_gather_stats_with_counts``.
        process_group (str): Communication group to synchronize over.
        world_size (int): Number of ranks in ``process_group``.
        self_num_features (int): Channel count the layer was built for.
        self_world_size (int): World size the layer was built for.

    Returns:
        tuple: ``(out, mean, invstd, counts)``. ``counts`` holds the gathered
        per-rank element counts flattened to 1-D. The last three entries are
        what ``bprop_pynative`` reads from ``output``.

    Raises:
        ValueError: If ``world_size`` differs from ``self_world_size``, if the
            channel count differs from ``self_num_features``, or if there is
            only one value per channel while ``world_size < 2``.
    """
    if self_world_size != world_size:
        raise ValueError('World Size Error')
    if not input.is_contiguous():
        input = input.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    input_shape = shape(input)
    num_channels = input_shape[1]
    # Validate the channel count before doing any computation or communication.
    if self_num_features != num_channels:
        raise ValueError('Features Error')

    input_numel = ops.numel(input)
    size = int(input_numel // num_channels)
    if size == 1 and world_size < 2:
        raise ValueError(
            'Expected more than 1 value per channel when training, got input size {}'.format(size))

    # Local per-channel mean / inverse std, and the per-channel element count.
    mean, invstd = mint.batch_norm_stats(input, eps)
    count = mint.full((1,), size, dtype=mean.dtype)

    # C, C, 1 -> (2C + 1)
    combined = mint.cat([mean, invstd, count], dim=0)
    # All-gather rather than all-reduce: counts may differ across ranks, so a
    # plain all-reduce cannot produce correct global statistics. The gathered
    # means / invstds / counts are merged by batch_norm_gather_stats_with_counts.
    combined, _ = all_gather_into_tensor_dy(combined, process_group)
    # -> world_size * (2C + 1)
    combined = ops.reshape(combined, [world_size, -1])
    # -> world_size * C, world_size * C, world_size * 1
    mean_val_all, invstd_val_all, count_val_all = mint.split(combined, num_channels, dim=1)
    counts = count_val_all.view(-1)

    # Global mean & invstd (running statistics are passed to the kernel).
    mean, invstd = mint.batch_norm_gather_stats_with_counts(input, mean_val_all, invstd_val_all, running_mean,
                                                            running_var, momentum, eps, counts)

    # Element-wise normalization with the global statistics.
    out = mint.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
    return (out, mean, invstd, counts)
246
+
247
+
248
def construct_kbk(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
                  world_size, self_num_features, self_world_size):
    """Forward pass of synchronized batch normalization outside PyNative mode.

    Same algorithm as ``construct_pynative``, with two differences. ``input``
    is made contiguous unconditionally. The statistics are gathered with this
    module's ``all_gather_into_tensor`` (``P.AllGather``).

    Args:
        input (Tensor): Input whose dimension 1 is the channel dimension.
        weight: Affine scale; may be ``None``. Made contiguous when present.
        bias: Affine shift, forwarded to ``batch_norm_elemt``.
        running_mean, running_var: Forwarded to ``batch_norm_gather_stats_with_counts``.
        eps (float): Epsilon forwarded to the batch-norm kernels.
        momentum (float): Forwarded to ``batch_norm_gather_stats_with_counts``.
        process_group (str): Communication group to synchronize over.
        world_size (int): Number of ranks in ``process_group``.
        self_num_features (int): Channel count the layer was built for.
        self_world_size (int): World size the layer was built for.

    Returns:
        tuple: ``(out, mean, invstd, counts)`` with ``counts`` being the
        gathered per-rank element counts flattened to 1-D.

    Raises:
        ValueError: If ``world_size`` differs from ``self_world_size``, if the
            channel count differs from ``self_num_features``, or if there is
            only one value per channel while ``world_size < 2``.
    """
    if self_world_size != world_size:
        raise ValueError('World Size Error')
    input = input.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    input_shape = shape(input)
    num_channels = input_shape[1]
    # Validate the channel count before doing any computation or communication.
    if self_num_features != num_channels:
        raise ValueError('Features Error')

    input_numel = ops.numel(input)
    size = int(input_numel // num_channels)
    if size == 1 and world_size < 2:
        raise ValueError(
            'Expected more than 1 value per channel when training, got input size {}'.format(size))

    # Local per-channel mean / inverse std, and the per-channel element count.
    mean, invstd = mint.batch_norm_stats(input, eps)
    count = mint.full((1,), input_numel // num_channels, dtype=mean.dtype)

    # C, C, 1 -> (2C + 1)
    combined = mint.cat([mean, invstd, count], dim=0)
    # All-gather rather than all-reduce: counts may differ across ranks, so a
    # plain all-reduce cannot produce correct global statistics.
    combined = all_gather_into_tensor(combined, process_group)
    # -> world_size * (2C + 1)
    combined = ops.reshape(combined, [world_size, -1])
    # -> world_size * C, world_size * C, world_size * 1
    mean_all, invstd_all, count_all = mint.split(combined, num_channels, dim=1)
    counts = count_all.view(-1)

    # Global mean & invstd (running statistics are passed to the kernel).
    mean, invstd = mint.batch_norm_gather_stats_with_counts(
        input,
        mean_all,
        invstd_all,
        running_mean,
        running_var,
        momentum,
        eps,
        counts
    )

    # Element-wise normalization with the global statistics.
    out = mint.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
    return (out, mean, invstd, counts)
298
+
299
+
300
class SyncBatchNormInner(Cell):
    """Synchronized batch-norm cell with a hand-written backward pass.

    The forward/backward implementation pair is chosen once, when the cell is
    created, from the current execution mode. PyNative mode uses
    ``construct_pynative`` / ``bprop_pynative``; any other mode uses
    ``construct_kbk`` / ``bprop_kbk``.

    Args:
        self_num_features (int): Channel count the layer is built for.
        self_world_size (int): Number of ranks the layer is built for.
    """

    def __init__(self, self_num_features, self_world_size):
        super(SyncBatchNormInner, self).__init__()
        self.num_features = self_num_features
        self.world_size = self_world_size
        self.mode = context.get_context("mode")
        # NOTE: the mode is sampled only here; changing the execution mode after
        # construction does not switch the implementation.
        if self.mode == context.PYNATIVE_MODE:
            self.fn_bprop = bprop_pynative
            self.fn_construct = construct_pynative
        else:
            self.fn_bprop = bprop_kbk
            self.fn_construct = construct_kbk

    def construct(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        """Run the selected forward; returns ``(out, mean, invstd, counts)``."""
        return self.fn_construct(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
                                 world_size, self.num_features, self.world_size)

    def bprop(self, input_x, weight, bias, running_mean, running_var, eps, momentum,
              process_group, world_size, output, doutput):
        """Custom backward: delegates to the backward matching the selected forward."""
        return self.fn_bprop(input_x, weight, bias, running_mean, running_var, eps, momentum,
                             process_group, world_size, output, doutput)
321
+
322
+
323
class _SyncBatchNorm(Cell):
    """Wrapper around ``SyncBatchNormInner`` that returns only the normalized output.

    Args:
        num_features (int): Expected channel count of the input.
        world_size (int): Expected number of ranks in the process group.
        dtype: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, num_features, world_size, dtype=mindspore.float32):
        super().__init__()
        self.num_features = num_features
        self.world_size = world_size
        self.inner = SyncBatchNormInner(self.num_features, self.world_size)

    def construct(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        """Run synchronized batch norm and keep only the first output."""
        normalized, _, _, _ = self.inner(input, weight, bias, running_mean, running_var,
                                         eps, momentum, process_group, world_size)
        return normalized
@@ -77,6 +77,55 @@ class SiLU(Cell):
77
77
  return mint.nn.functional.silu(x)
78
78
 
79
79
 
80
class Sigmoid(Cell):
    r"""
    Applies sigmoid activation function element-wise.

    Sigmoid function is defined as:

    .. math::

        \text{sigmoid}(x_i) = \frac{1}{1 + \exp(-x_i)},

    where :math:`x_i` is the element of `input`.

    Sigmoid Activation Function Graph:

    .. image:: ../images/Sigmoid.png
        :align: center

    Inputs:
        - **input** (Tensor) - `input` is :math:`x` in the preceding formula. Tensor of any dimension,
          the data type is float16, float32, float64, complex64 or complex128.

    Outputs:
        Tensor, with the same type and shape as the `input`.

    Raises:
        TypeError: If dtype of `input` is not float16, float32, float64, complex64 or complex128.
        TypeError: If `input` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, mint
        >>> import numpy as np
        >>> input = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
        >>> sigmoid = mint.nn.Sigmoid()
        >>> output = sigmoid(input)
        >>> print(output)
        [0.2688 0.11914 0.5 0.881 0.7305 ]
    """
    def __init__(self):
        """Initialize Sigmoid."""
        super(Sigmoid, self).__init__()

    def construct(self, input):
        return mint.nn.functional.sigmoid(input)
127
+
128
+
80
129
  class LogSigmoid(Cell):
81
130
  r"""
82
131
  Applies logsigmoid activation element-wise. The input is a Tensor with any valid shape.
@@ -84,7 +133,7 @@ class LogSigmoid(Cell):
84
133
  Logsigmoid is defined as:
85
134
 
86
135
  .. math::
87
- \text{logsigmoid}(x_{i}) = \log(\frac{1}{1 + \exp(-x_i)}),
136
+ \text{LogSigmoid}(x_{i}) = \log(\frac{1}{1 + \exp(-x_i)}),
88
137
 
89
138
  where :math:`x_{i}` is the element of the input.
90
139
 
@@ -127,7 +176,233 @@ class LogSigmoid(Cell):
127
176
  return mint.nn.functional.logsigmoid(input)
128
177
 
129
178
 
179
class ELU(Cell):
    r"""
    Exponential Linear Unit activation function

    Applies the exponential linear unit function element-wise. The activation function is defined as:

    .. math::
        ELU_{i} =
        \begin{cases}
        x_i, &\text{if } x_i \geq 0; \cr
        \alpha * (\exp(x_i) - 1), &\text{otherwise.}
        \end{cases}

    where :math:`x_i` represents the element of the input and :math:`\alpha` represents the `alpha` parameter, and
    `alpha` represents the smoothness of the ELU.

    ELU Activation Function Graph:

    .. image:: ../images/ELU.png
        :align: center

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Args:
        alpha (float, optional): The alpha value of ELU, the data type is float. Default: ``1.0``.
        inplace (bool, optional): Whether to use inplace mode, the data type is bool. Default: ``False``.

    Inputs:
        - **input** (Tensor) - The input of ELU is a Tensor of any dimension.

    Outputs:
        Tensor, with the same shape and type as the `input`.

    Raises:
        RuntimeError: If the dtype of `input` is not float16, float32 or bfloat16.
        TypeError: If the dtype of `alpha` is not float.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, mint
        >>> import numpy as np
        >>> input = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
        >>> elu = mint.nn.ELU()
        >>> result = elu(input)
        >>> print(result)
        [-0.63212055 -0.86466473 0. 2. 1.]
    """

    def __init__(self, alpha=1.0, inplace=False):
        """Initialize ELU."""
        super(ELU, self).__init__()
        self.alpha = alpha
        self.inplace = inplace

    def construct(self, input):
        return mint.nn.functional.elu(input, self.alpha, self.inplace)
239
+
240
+
241
class GLU(Cell):
    r"""
    Computes GLU (Gated Linear Unit activation function) of the input tensor.

    .. math::
        {GLU}(a, b)= a \otimes \sigma(b)

    where :math:`a` is the first half of the `input` Tensor after `input` is split and :math:`b` is the second half.

    Here :math:`\sigma` is the sigmoid function, and :math:`\otimes` is the Hadamard product.
    See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`_ .

    Args:
        dim (int, optional): The dimension to split the input `input`. The value range is `[-r, r)` where `r`
            is the number of dimensions of `input`. Default: ``-1`` , the last dimension in `input`.

    Inputs:
        - **input** (Tensor) - Tensor to be calculated. Dtype is floating point and the shape
          is :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional dimensions. :math:`N`
          is required to be an even number, where :math:`N` is the size of `input` on the dimension
          selected by `dim`.

    Outputs:
        Tensor, the same dtype as the `input`, with the shape :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`.

    Raises:
        TypeError: If `input` is not a Tensor or `dim` is not an int.
        IndexError: If the value of `dim` is out of the range of `[-r, r)`, where `r` is the number
            of dimensions of `input`.
        RuntimeError: If dtype of `input` is not supported.
        RuntimeError: If the length of `input` in the dimension selected by `dim` is not even.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> from mindspore import mint, Tensor
        >>> glu = mint.nn.GLU()
        >>> input = Tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
        >>> output = glu(input)
        >>> print(output)
        [[0.05744425 0.11973753]
         [0.33409387 0.41398472]]
    """

    def __init__(self, dim=-1):
        """Initialize GLU."""
        # NOTE(review): this previously called super().__init__("GLU"). Per the
        # mindspore.nn.Cell API, the first positional parameter is `auto_prefix`
        # (a bool), not a cell name, so the string was only acting as a truthy flag.
        super().__init__()
        self.dim = dim

    def construct(self, input):
        return mint.nn.functional.glu(input, self.dim)
293
+
294
+
295
class Tanh(Cell):
    r"""
    Applies the Tanh function element-wise, returning a new tensor that holds the hyperbolic
    tangent of each element of the input.

    Tanh function is defined as:

    .. math::
        tanh(x_i) = \frac{\exp(x_i) - \exp(-x_i)}{\exp(x_i) + \exp(-x_i)} = \frac{\exp(2x_i) - 1}{\exp(2x_i) + 1},

    where :math:`x_i` is an element of the input Tensor.

    Tanh Activation Function Graph:

    .. image:: ../images/Tanh.png
        :align: center

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Inputs:
        - **input** (Tensor) - Tensor of any dimension, input with data type of float16 or float32.

    Outputs:
        Tensor, with the same type and shape as the `input`.

    Raises:
        TypeError: If dtype of `input` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, mint
        >>> import numpy as np
        >>> input = Tensor(np.array([1, 2, 3, 2, 1]), mindspore.float16)
        >>> tanh = mint.nn.Tanh()
        >>> output = tanh(input)
        >>> print(output)
        [0.7617 0.964 0.995 0.964 0.7617]
    """

    def __init__(self):
        """Initialize Tanh (the cell holds no state)."""
        super().__init__()

    def construct(self, input):
        """Delegate to the functional tanh."""
        tanh_fn = mint.nn.functional.tanh
        return tanh_fn(input)
343
+
344
+
345
class Threshold(Cell):
    r"""
    Compute the Threshold activation function element-wise.

    The Threshold is defined as:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Args:
        threshold (Union[int, float]): The value of the threshold.
        value (Union[int, float]): The value to replace with when the element is less than or
            equal to `threshold`.
        inplace (bool, optional): Whether to modify `input` in place. Default: ``False``.

    Inputs:
        - **input** (Tensor) - The input Tensor.

    Outputs:
        Tensor, the same shape and data type as the input.

    Raises:
        TypeError: If `input` is not a Tensor.
        TypeError: If `threshold` is not a float or an int.
        TypeError: If `value` is not a float or an int.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, mint
        >>> inputs = mindspore.Tensor([0.0, 2, 3], mindspore.float32)
        >>> net = mint.nn.Threshold(1, 100)
        >>> outputs = net(inputs)
        >>> print(outputs)
        [100. 2. 3.]
    """

    def __init__(self, threshold, value, inplace=False):
        """Initialize Threshold."""
        super(Threshold, self).__init__()
        self.threshold = threshold
        self.value = value
        self.inplace = inplace

    def construct(self, input):
        return mint.nn.functional.threshold(input, self.threshold, self.value,
                                            self.inplace)
400
+
130
401
# Activation cells exported by this module; order preserved.
# NOTE(review): `Sigmoid` is defined in this module but is not listed here;
# confirm whether that omission is intentional.
__all__ = [
    'LogSigmoid', 'SiLU',
    'ELU', 'GLU', 'Tanh', 'Threshold',
]