mindspore 2.6.0__cp311-cp311-win_amd64.whl → 2.7.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (455)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +2 -2
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +42 -11
  9. mindspore/_extends/builtin_operations.py +3 -3
  10. mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
  11. mindspore/_extends/optimize/cell_utils.py +96 -0
  12. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +3 -3
  15. mindspore/_extends/parse/compile_config.py +44 -22
  16. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
  17. mindspore/_extends/parse/parser.py +64 -83
  18. mindspore/_extends/parse/resources.py +39 -0
  19. mindspore/_extends/parse/standard_method.py +47 -14
  20. mindspore/_extends/parse/trope.py +8 -1
  21. mindspore/_extends/pijit/__init__.py +1 -2
  22. mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
  23. mindspore/amp.py +4 -22
  24. mindspore/atlprov.dll +0 -0
  25. mindspore/avcodec-59.dll +0 -0
  26. mindspore/avdevice-59.dll +0 -0
  27. mindspore/avfilter-8.dll +0 -0
  28. mindspore/avformat-59.dll +0 -0
  29. mindspore/avutil-57.dll +0 -0
  30. mindspore/boost/adasum.py +1 -1
  31. mindspore/boost/boost_cell_wrapper.py +4 -4
  32. mindspore/c1.dll +0 -0
  33. mindspore/c1xx.dll +0 -0
  34. mindspore/c2.dll +0 -0
  35. mindspore/common/__init__.py +43 -12
  36. mindspore/common/_grad_function.py +2 -1
  37. mindspore/common/_pijit_context.py +28 -7
  38. mindspore/common/_stub_tensor.py +1 -209
  39. mindspore/common/_tensor_cpp_method.py +1 -1
  40. mindspore/common/_tensor_docs.py +177 -52
  41. mindspore/common/_utils.py +9 -1
  42. mindspore/common/api.py +338 -208
  43. mindspore/common/dtype.py +108 -57
  44. mindspore/common/dump.py +11 -16
  45. mindspore/common/dynamic_shape/__init__.py +0 -0
  46. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
  47. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  48. mindspore/common/file_system.py +59 -9
  49. mindspore/common/generator.py +2 -3
  50. mindspore/common/hook_handle.py +33 -5
  51. mindspore/common/jit_config.py +1 -1
  52. mindspore/common/jit_trace.py +84 -105
  53. mindspore/common/np_dtype.py +3 -3
  54. mindspore/common/parameter.py +27 -29
  55. mindspore/common/recompute.py +5 -7
  56. mindspore/common/sparse_tensor.py +0 -3
  57. mindspore/common/symbol.py +0 -1
  58. mindspore/common/tensor.py +84 -133
  59. mindspore/communication/_comm_helper.py +46 -4
  60. mindspore/communication/management.py +79 -7
  61. mindspore/context.py +47 -38
  62. mindspore/dataset/__init__.py +1 -1
  63. mindspore/dataset/audio/transforms.py +1 -1
  64. mindspore/dataset/core/config.py +38 -4
  65. mindspore/dataset/engine/datasets.py +350 -322
  66. mindspore/dataset/engine/datasets_user_defined.py +69 -23
  67. mindspore/dataset/engine/iterators.py +2 -2
  68. mindspore/dataset/engine/obs/config_loader.py +2 -2
  69. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
  70. mindspore/dataset/transforms/c_transforms.py +2 -2
  71. mindspore/dataset/transforms/py_transforms.py +7 -3
  72. mindspore/dataset/transforms/transforms.py +10 -6
  73. mindspore/dataset/vision/__init__.py +1 -1
  74. mindspore/dataset/vision/py_transforms.py +8 -8
  75. mindspore/dataset/vision/transforms.py +17 -5
  76. mindspore/dataset/vision/utils.py +632 -21
  77. mindspore/dataset/vision/validators.py +1 -0
  78. mindspore/device_context/ascend/device.py +1 -1
  79. mindspore/device_context/ascend/op_tuning.py +35 -1
  80. mindspore/device_context/gpu/__init__.py +2 -2
  81. mindspore/device_context/gpu/device.py +1 -1
  82. mindspore/device_context/gpu/op_precision.py +4 -2
  83. mindspore/device_context/gpu/op_tuning.py +6 -3
  84. mindspore/device_manager.py +16 -9
  85. mindspore/dnnl.dll +0 -0
  86. mindspore/dpcmi.dll +0 -0
  87. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +5 -4
  88. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  89. mindspore/experimental/optim/adadelta.py +13 -20
  90. mindspore/experimental/optim/adagrad.py +15 -22
  91. mindspore/experimental/optim/adam.py +17 -24
  92. mindspore/experimental/optim/adamax.py +14 -22
  93. mindspore/experimental/optim/adamw.py +28 -34
  94. mindspore/experimental/optim/asgd.py +15 -25
  95. mindspore/experimental/optim/lr_scheduler.py +27 -45
  96. mindspore/experimental/optim/nadam.py +14 -24
  97. mindspore/experimental/optim/optimizer.py +13 -23
  98. mindspore/experimental/optim/radam.py +18 -24
  99. mindspore/experimental/optim/rmsprop.py +14 -25
  100. mindspore/experimental/optim/rprop.py +15 -26
  101. mindspore/experimental/optim/sgd.py +9 -19
  102. mindspore/hal/__init__.py +4 -4
  103. mindspore/hal/contiguous_tensors_handle.py +2 -2
  104. mindspore/hal/memory.py +1 -0
  105. mindspore/include/api/cell.h +65 -5
  106. mindspore/include/api/cfg.h +24 -7
  107. mindspore/include/api/context.h +1 -0
  108. mindspore/include/api/delegate.h +10 -2
  109. mindspore/include/api/dual_abi_helper.h +100 -19
  110. mindspore/include/api/graph.h +14 -1
  111. mindspore/include/api/kernel.h +16 -3
  112. mindspore/include/api/kernel_api.h +9 -1
  113. mindspore/include/api/metrics/accuracy.h +9 -0
  114. mindspore/include/api/model.h +8 -1
  115. mindspore/include/api/model_group.h +4 -0
  116. mindspore/include/api/model_parallel_runner.h +2 -0
  117. mindspore/include/api/status.h +48 -10
  118. mindspore/include/api/types.h +8 -3
  119. mindspore/include/c_api/model_c.h +0 -58
  120. mindspore/include/c_api/tensor_c.h +0 -26
  121. mindspore/include/dataset/constants.h +9 -0
  122. mindspore/include/dataset/vision_ascend.h +1 -1
  123. mindspore/jpeg62.dll +0 -0
  124. mindspore/mindrecord/tools/cifar10.py +61 -11
  125. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
  126. mindspore/mindspore_backend_common.dll +0 -0
  127. mindspore/mindspore_backend_manager.dll +0 -0
  128. mindspore/mindspore_common.dll +0 -0
  129. mindspore/mindspore_core.dll +0 -0
  130. mindspore/mindspore_cpu_res_manager.dll +0 -0
  131. mindspore/mindspore_dump.dll +0 -0
  132. mindspore/mindspore_frontend.dll +0 -0
  133. mindspore/mindspore_glog.dll +0 -0
  134. mindspore/mindspore_memory_pool.dll +0 -0
  135. mindspore/mindspore_ms_backend.dll +0 -0
  136. mindspore/mindspore_ops.dll +0 -0
  137. mindspore/mindspore_ops_host.dll +0 -0
  138. mindspore/mindspore_ops_kernel_common.dll +0 -0
  139. mindspore/mindspore_profiler.dll +0 -0
  140. mindspore/mindspore_pyboost.dll +0 -0
  141. mindspore/mindspore_pynative.dll +0 -0
  142. mindspore/mindspore_res_manager.dll +0 -0
  143. mindspore/mindspore_runtime_pipeline.dll +0 -0
  144. mindspore/mint/__init__.py +4 -44
  145. mindspore/mint/distributed/__init__.py +5 -0
  146. mindspore/mint/distributed/distributed.py +425 -19
  147. mindspore/mint/nn/__init__.py +1 -1
  148. mindspore/mint/nn/functional.py +53 -6
  149. mindspore/mint/nn/layer/_functions.py +163 -294
  150. mindspore/mint/nn/layer/activation.py +8 -6
  151. mindspore/mint/nn/layer/conv.py +125 -101
  152. mindspore/mint/nn/layer/normalization.py +11 -25
  153. mindspore/mint/optim/adam.py +19 -18
  154. mindspore/mint/optim/adamw.py +14 -8
  155. mindspore/mint/optim/sgd.py +5 -5
  156. mindspore/msobj140.dll +0 -0
  157. mindspore/mspdb140.dll +0 -0
  158. mindspore/mspdbcore.dll +0 -0
  159. mindspore/mspdbst.dll +0 -0
  160. mindspore/mspft140.dll +0 -0
  161. mindspore/msvcdis140.dll +0 -0
  162. mindspore/msvcp140_1.dll +0 -0
  163. mindspore/msvcp140_2.dll +0 -0
  164. mindspore/msvcp140_atomic_wait.dll +0 -0
  165. mindspore/msvcp140_codecvt_ids.dll +0 -0
  166. mindspore/nn/cell.py +488 -620
  167. mindspore/nn/grad/cell_grad.py +11 -12
  168. mindspore/nn/layer/activation.py +36 -36
  169. mindspore/nn/layer/basic.py +74 -77
  170. mindspore/nn/layer/channel_shuffle.py +4 -4
  171. mindspore/nn/layer/combined.py +4 -2
  172. mindspore/nn/layer/conv.py +86 -85
  173. mindspore/nn/layer/dense.py +9 -7
  174. mindspore/nn/layer/embedding.py +50 -52
  175. mindspore/nn/layer/image.py +38 -40
  176. mindspore/nn/layer/math.py +111 -112
  177. mindspore/nn/layer/normalization.py +56 -44
  178. mindspore/nn/layer/pooling.py +58 -63
  179. mindspore/nn/layer/rnn_cells.py +33 -33
  180. mindspore/nn/layer/rnns.py +56 -56
  181. mindspore/nn/layer/thor_layer.py +74 -73
  182. mindspore/nn/layer/transformer.py +11 -1
  183. mindspore/nn/learning_rate_schedule.py +20 -20
  184. mindspore/nn/loss/loss.py +79 -81
  185. mindspore/nn/optim/adam.py +2 -4
  186. mindspore/nn/optim/adasum.py +2 -2
  187. mindspore/nn/optim/lamb.py +1 -3
  188. mindspore/nn/optim/optimizer.py +1 -1
  189. mindspore/nn/optim/tft_wrapper.py +2 -3
  190. mindspore/nn/optim/thor.py +2 -2
  191. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  192. mindspore/nn/probability/distribution/exponential.py +2 -1
  193. mindspore/nn/probability/distribution/poisson.py +2 -1
  194. mindspore/nn/sparse/sparse.py +3 -3
  195. mindspore/nn/wrap/cell_wrapper.py +73 -42
  196. mindspore/nn/wrap/grad_reducer.py +37 -52
  197. mindspore/nn/wrap/loss_scale.py +72 -74
  198. mindspore/numpy/array_creations.py +7 -7
  199. mindspore/numpy/fft.py +1 -1
  200. mindspore/numpy/math_ops.py +1 -1
  201. mindspore/numpy/utils_const.py +1 -1
  202. mindspore/opencv_core452.dll +0 -0
  203. mindspore/opencv_imgcodecs452.dll +0 -0
  204. mindspore/opencv_imgproc452.dll +0 -0
  205. mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
  206. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
  207. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  208. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  209. mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
  210. mindspore/ops/_vmap/vmap_array_ops.py +6 -13
  211. mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
  212. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +29 -10
  213. mindspore/ops/auto_generate/gen_extend_func.py +5 -55
  214. mindspore/ops/auto_generate/gen_ops_def.py +753 -273
  215. mindspore/ops/auto_generate/gen_ops_prim.py +1687 -958
  216. mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
  217. mindspore/ops/composite/__init__.py +10 -0
  218. mindspore/ops/composite/base.py +9 -5
  219. mindspore/ops/composite/multitype_ops/__init__.py +12 -1
  220. mindspore/ops/composite/multitype_ops/_compile_utils.py +132 -108
  221. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  222. mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
  223. mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
  224. mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
  225. mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
  226. mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
  227. mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
  228. mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
  229. mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
  230. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
  231. mindspore/ops/function/__init__.py +4 -1
  232. mindspore/ops/function/_add_attr_func.py +11 -6
  233. mindspore/ops/function/array_func.py +17 -100
  234. mindspore/ops/function/debug_func.py +8 -5
  235. mindspore/ops/function/grad/grad_func.py +5 -13
  236. mindspore/ops/function/math_func.py +65 -399
  237. mindspore/ops/function/nn_func.py +44 -61
  238. mindspore/ops/function/other_func.py +4 -1
  239. mindspore/ops/function/random_func.py +31 -4
  240. mindspore/ops/functional.py +2 -3
  241. mindspore/ops/functional_overload.py +486 -18
  242. mindspore/ops/op_info_register.py +21 -0
  243. mindspore/ops/operations/__init__.py +5 -2
  244. mindspore/ops/operations/_custom_ops_utils.py +675 -8
  245. mindspore/ops/operations/_inner_ops.py +14 -18
  246. mindspore/ops/operations/_sequence_ops.py +1 -1
  247. mindspore/ops/operations/array_ops.py +4 -50
  248. mindspore/ops/operations/comm_ops.py +186 -41
  249. mindspore/ops/operations/custom_ops.py +244 -175
  250. mindspore/ops/operations/debug_ops.py +55 -4
  251. mindspore/ops/operations/image_ops.py +13 -13
  252. mindspore/ops/operations/manually_defined/ops_def.py +27 -28
  253. mindspore/ops/operations/math_ops.py +8 -9
  254. mindspore/ops/operations/nn_ops.py +6 -7
  255. mindspore/ops/primitive.py +9 -20
  256. mindspore/ops/tensor_method.py +52 -11
  257. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
  258. mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
  259. mindspore/ops_generate/api/functions_cc_generator.py +58 -10
  260. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
  261. mindspore/ops_generate/common/base_generator.py +14 -0
  262. mindspore/ops_generate/common/gen_constants.py +7 -2
  263. mindspore/ops_generate/common/gen_utils.py +0 -19
  264. mindspore/ops_generate/common/op_proto.py +11 -4
  265. mindspore/ops_generate/common/template.py +88 -11
  266. mindspore/ops_generate/gen_ops.py +1 -1
  267. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
  268. mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
  269. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
  270. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
  271. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
  272. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
  273. mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
  274. mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
  275. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
  276. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
  277. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
  278. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
  279. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
  280. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
  281. mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
  282. mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
  283. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
  284. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
  285. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
  286. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
  287. mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
  288. mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
  289. mindspore/parallel/_auto_parallel_context.py +9 -17
  290. mindspore/parallel/_cell_wrapper.py +106 -40
  291. mindspore/parallel/_parallel_serialization.py +4 -3
  292. mindspore/parallel/_ps_context.py +4 -6
  293. mindspore/parallel/_tensor.py +167 -12
  294. mindspore/parallel/_transformer/moe.py +1 -1
  295. mindspore/parallel/_transformer/transformer.py +17 -12
  296. mindspore/parallel/_utils.py +5 -11
  297. mindspore/parallel/auto_parallel.py +33 -12
  298. mindspore/parallel/checkpoint_convert.py +3 -3
  299. mindspore/parallel/checkpoint_transform.py +5 -1
  300. mindspore/parallel/cluster/process_entity/_api.py +88 -49
  301. mindspore/parallel/cluster/process_entity/_utils.py +95 -7
  302. mindspore/parallel/cluster/run.py +48 -7
  303. mindspore/parallel/function/__init__.py +8 -1
  304. mindspore/parallel/function/reshard_func.py +7 -6
  305. mindspore/parallel/nn/__init__.py +15 -2
  306. mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
  307. mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
  308. mindspore/parallel/shard.py +9 -23
  309. mindspore/parallel/transform_safetensors.py +468 -174
  310. mindspore/pgodb140.dll +0 -0
  311. mindspore/pgort140.dll +0 -0
  312. mindspore/profiler/__init__.py +2 -1
  313. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
  314. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
  315. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +3 -0
  316. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
  317. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  318. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
  319. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
  320. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
  321. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
  322. mindspore/profiler/analysis/task_manager.py +1 -1
  323. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
  324. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
  325. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
  326. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
  327. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
  328. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
  329. mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
  330. mindspore/profiler/common/constant.py +16 -0
  331. mindspore/profiler/common/msprof_cmd_tool.py +2 -2
  332. mindspore/profiler/common/path_manager.py +9 -0
  333. mindspore/profiler/common/profiler_context.py +50 -29
  334. mindspore/profiler/common/profiler_info.py +0 -16
  335. mindspore/profiler/common/profiler_meta_data.py +1 -0
  336. mindspore/profiler/common/profiler_op_analyse.py +239 -0
  337. mindspore/profiler/common/profiler_output_path.py +23 -8
  338. mindspore/profiler/common/profiler_parameters.py +128 -35
  339. mindspore/profiler/dynamic_profile/__init__.py +0 -0
  340. mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
  341. mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
  342. mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
  343. mindspore/profiler/dynamic_profiler.py +374 -338
  344. mindspore/profiler/envprofiler.py +42 -12
  345. mindspore/profiler/experimental_config.py +112 -7
  346. mindspore/profiler/mstx.py +33 -12
  347. mindspore/profiler/platform/__init__.py +2 -3
  348. mindspore/profiler/platform/cpu_profiler.py +10 -4
  349. mindspore/profiler/platform/npu_profiler.py +30 -20
  350. mindspore/profiler/profiler.py +218 -154
  351. mindspore/profiler/profiler_action_controller.py +65 -77
  352. mindspore/profiler/profiler_interface.py +2 -2
  353. mindspore/profiler/schedule.py +10 -4
  354. mindspore/rewrite/common/config.py +1 -0
  355. mindspore/rewrite/common/namer.py +1 -0
  356. mindspore/rewrite/common/namespace.py +1 -0
  357. mindspore/rewrite/node/node.py +31 -11
  358. mindspore/rewrite/parsers/assign_parser.py +1 -1
  359. mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
  360. mindspore/run_check/_check_version.py +7 -10
  361. mindspore/runtime/__init__.py +8 -6
  362. mindspore/runtime/event.py +10 -4
  363. mindspore/runtime/executor.py +87 -45
  364. mindspore/runtime/memory.py +22 -30
  365. mindspore/runtime/thread_bind_core.py +299 -165
  366. mindspore/safeguard/rewrite_obfuscation.py +12 -13
  367. mindspore/swresample-4.dll +0 -0
  368. mindspore/swscale-6.dll +0 -0
  369. mindspore/tbbmalloc.dll +0 -0
  370. mindspore/tinyxml2.dll +0 -0
  371. mindspore/train/_utils.py +9 -5
  372. mindspore/train/amp.py +43 -23
  373. mindspore/train/callback/__init__.py +5 -5
  374. mindspore/train/callback/_callback.py +2 -1
  375. mindspore/train/callback/_checkpoint.py +4 -14
  376. mindspore/train/callback/_flops_collector.py +11 -7
  377. mindspore/train/callback/_landscape.py +0 -1
  378. mindspore/train/callback/_train_fault_tolerance.py +72 -18
  379. mindspore/train/data_sink.py +15 -6
  380. mindspore/train/dataset_helper.py +14 -5
  381. mindspore/train/model.py +49 -47
  382. mindspore/train/serialization.py +168 -126
  383. mindspore/train/summary/summary_record.py +13 -2
  384. mindspore/train/train_thor/model_thor.py +2 -2
  385. mindspore/turbojpeg.dll +0 -0
  386. mindspore/utils/__init__.py +3 -2
  387. mindspore/utils/dryrun.py +0 -6
  388. mindspore/utils/runtime_execution_order_check.py +162 -78
  389. mindspore/utils/sdc_detect.py +68 -0
  390. mindspore/utils/utils.py +14 -17
  391. mindspore/vcmeta.dll +0 -0
  392. mindspore/vcruntime140.dll +0 -0
  393. mindspore/vcruntime140_1.dll +0 -0
  394. mindspore/version.py +1 -1
  395. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
  396. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/RECORD +400 -439
  397. mindspore/_deprecated/jit.py +0 -198
  398. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  399. mindspore/communication/_hccl_management.py +0 -297
  400. mindspore/experimental/es/embedding_service.py +0 -891
  401. mindspore/experimental/es/embedding_service_layer.py +0 -581
  402. mindspore/profiler/common/validator/__init__.py +0 -14
  403. mindspore/profiler/common/validator/validate_path.py +0 -84
  404. mindspore/profiler/parser/__init__.py +0 -14
  405. mindspore/profiler/parser/aicpu_data_parser.py +0 -272
  406. mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
  407. mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
  408. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
  409. mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
  410. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
  411. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
  412. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
  413. mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
  414. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
  415. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
  416. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
  417. mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
  418. mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
  419. mindspore/profiler/parser/ascend_flops_generator.py +0 -116
  420. mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
  421. mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
  422. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  423. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  424. mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
  425. mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
  426. mindspore/profiler/parser/ascend_op_generator.py +0 -334
  427. mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
  428. mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
  429. mindspore/profiler/parser/base_timeline_generator.py +0 -483
  430. mindspore/profiler/parser/container.py +0 -229
  431. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
  432. mindspore/profiler/parser/flops_parser.py +0 -531
  433. mindspore/profiler/parser/framework_enum.py +0 -111
  434. mindspore/profiler/parser/framework_parser.py +0 -464
  435. mindspore/profiler/parser/framework_struct.py +0 -61
  436. mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
  437. mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
  438. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
  439. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
  440. mindspore/profiler/parser/hccl_parser.py +0 -573
  441. mindspore/profiler/parser/hwts_log_parser.py +0 -122
  442. mindspore/profiler/parser/integrator.py +0 -526
  443. mindspore/profiler/parser/memory_usage_parser.py +0 -277
  444. mindspore/profiler/parser/minddata_analyzer.py +0 -800
  445. mindspore/profiler/parser/minddata_parser.py +0 -186
  446. mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
  447. mindspore/profiler/parser/op_intermediate_parser.py +0 -149
  448. mindspore/profiler/parser/optime_parser.py +0 -250
  449. mindspore/profiler/parser/profiler_info.py +0 -213
  450. mindspore/profiler/parser/step_trace_parser.py +0 -666
  451. mindspore/utils/hooks.py +0 -81
  452. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  453. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
  454. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
  455. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
@@ -1189,7 +1189,7 @@ class PixelShuffle(Cell):
  >>> pixel_shuffle = mint.nn.PixelShuffle(3)
  >>> input = mint.randn(1, 9, 4, 4)
  >>> output = pixel_shuffle(input)
- >>> print(output.shape())
+ >>> print(output.shape)
  [1, 1, 12, 12]
  """

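Editor's note: the doc fix above is needed because `Tensor.shape` is a property in MindSpore, so the old doctest's `output.shape()` would raise a TypeError. The doctest's shape arithmetic also checks out; a minimal sketch of the pixel-shuffle shape rule (illustrative only, not part of the diff):

    # PixelShuffle with upscale factor r maps (N, C*r*r, H, W) -> (N, C, H*r, W*r).
    r = 3
    n, c_in, h, w = 1, 9, 4, 4                      # the doctest's input shape
    out_shape = (n, c_in // (r * r), h * r, w * r)  # -> (1, 1, 12, 12), as printed above
    assert out_shape == (1, 1, 12, 12)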
@@ -67,7 +67,7 @@ from mindspore.ops.auto_generate import prelu
  # 20

  # 21
- from mindspore.ops.function.nn_func import conv3d_ext as conv3d
+ from mindspore.ops.functional_overload import conv3d
  # 22

  # 23
@@ -118,7 +118,8 @@ from mindspore.ops.auto_generate import soft_margin_loss
  # 45

  # 46
- from mindspore.ops.functional import silu
+ from mindspore.ops.auto_generate import silu as silu_func
+ from mindspore.ops.auto_generate import inplace_silu
  # 47

  # 48
@@ -283,6 +284,52 @@ from mindspore.ops.functional import adaptive_avg_pool2d_ext as adaptive_avg_poo
  from mindspore.ops.function.nn_func import cross_entropy_ext as cross_entropy
  from mindspore.ops.function.nn_func import nll_loss_ext as nll_loss

+ def silu(input, inplace=False):
+     r"""
+     Computes Sigmoid Linear Unit of input element-wise. The SiLU function is defined as:
+
+     .. math::
+
+         \text{SiLU}(x) = x * \sigma(x),
+
+     where :math:`x` is an element of the input and :math:`\sigma(x)` is the Sigmoid function.
+
+     .. math::
+
+         \sigma(x_i) = \frac{1}{1 + \exp(-x_i)},
+
+     SiLU Function Graph:
+
+     .. image:: ../images/SiLU.png
+         :align: center
+
+     Args:
+         input (Tensor): `input` is :math:`x` in the preceding formula. Input with the data type
+             float16 or float32.
+         inplace (bool, optional): If ``True``, enable the in-place update function. Default value: ``False``.
+
+     Returns:
+         Tensor, with the same type and shape as the `input`.
+
+     Raises:
+         TypeError: If dtype of `input` is neither float16 nor float32.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> import mindspore
+         >>> from mindspore import Tensor, mint
+         >>> import numpy as np
+         >>> input = Tensor(np.array([-1, 2, -3, 2, -1]), mindspore.float16)
+         >>> output = mint.nn.functional.silu(input, inplace=False)
+         >>> print(output)
+         [-0.269 1.762 -0.1423 1.762 -0.269]
+     """
+     if inplace:
+         return inplace_silu(input)
+     return silu_func(input)
+

  def elu(input, alpha=1.0, inplace=False):
      r"""
@@ -511,7 +558,10 @@ def binary_cross_entropy(input, target, weight=None, reduction='mean'):
      \end{cases}

  .. warning::
-     - The value of `input` must range from `0` to `1`.
+     The value of `input` must range from `0` to `1`.
+
+ .. note::
+     Currently, when the platform is Ascend, all gradient calculations are performed on the NPU.

  Args:
      input (Tensor): The predictive value whose data type must be float16 or float32.
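Editor's note: the `0` to `1` constraint above follows from the binary cross entropy formula, which takes log(input) and log(1 - input); outside that range the logarithms are undefined. A quick NumPy sketch of the per-element loss (illustrative only):

    import numpy as np

    # BCE per element: -(target * log(input) + (1 - target) * log(1 - input)).
    # log() is only finite for input inside (0, 1), hence the range warning.
    inp = np.array([0.2, 0.7])
    tgt = np.array([0.0, 1.0])
    bce = -(tgt * np.log(inp) + (1 - tgt) * np.log(1 - inp))
    print(np.round(bce, 4))   # [0.2231 0.3567]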
@@ -955,9 +1005,6 @@ def threshold(input, threshold, value, inplace=False):  # pylint: disable=W0621
          \text{value}, &\text{ otherwise }
      \end{cases}

- .. warning::
-     This is an experimental API that is subject to change or deletion.
-
  Args:
      input (Tensor): The input Tensor.
      threshold (Union[int, float]): The value of the threshold.
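Editor's note: the cases formula above keeps elements greater than `threshold` and substitutes `value` elsewhere; in NumPy terms (a sketch of the semantics, not the MindSpore kernel):

    import numpy as np

    # y = x where x > threshold, else value.
    x = np.array([-1.0, 0.5, 2.0])
    threshold, value = 0.1, 20.0
    y = np.where(x > threshold, x, value)
    print(y)   # [20.   0.5  2. ]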
@@ -1,323 +1,192 @@
+ # Copyright 2025 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """functions for mint"""
  import mindspore
- from mindspore import Tensor
- from mindspore import context
- import mindspore.communication
- import mindspore.communication.comm_func
+ from mindspore import ops, mint
+ from mindspore import _checkparam as validator
  from mindspore.nn.cell import Cell
+ from mindspore.communication.comm_func import all_gather_into_tensor
+ from mindspore.communication.comm_func import all_reduce
+ from mindspore.communication.management import get_rank, get_group_size, GlobalComm, _get_group
  from mindspore.ops.auto_generate.gen_ops_prim import BatchNormReduceGrad
  from mindspore.ops.auto_generate.gen_ops_prim import BatchNormElemtGrad
- from mindspore.communication import GlobalComm
- from mindspore.ops import ReduceOp
- from mindspore._c_expression import TensorPy as Tensor_
- from mindspore.communication._comm_helper import _get_size_helper, HCCL_WORLD_COMM_GROUP
- from mindspore.ops._primitive_cache import _get_cache_prim
- from mindspore.communication.comm_func import all_gather_into_tensor as all_gather_into_tensor_dy
- from mindspore.ops import operations as P
- from mindspore import ops, mint
-
-
- DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
+ from mindspore.ops.primitive import Primitive, prim_arg_register, PrimitiveWithInfer, prim_attr_register
+ from mindspore.ops.operations.comm_ops import ReduceOp, check_collective_target_dtype

  batch_norm_reduce_grad = BatchNormReduceGrad()
  batch_norm_elemt_grad = BatchNormElemtGrad()
- shape = P.Shape()
-
-
- def _deal_comm_outputs(output, async_op):
-     if isinstance(output, tuple):
-         if not async_op:
-             output[1].wait()
-         return output[0]
+ shape = ops.Shape()
+
+
+ class AllGather(PrimitiveWithInfer):
+     @prim_arg_register
+     def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
+         super(AllGather, self).__init__(self.__class__.__name__)
+         self.group = _get_group(group)
+         validator.check_value_type('group', self.group, (str,), self.name)
+         self.rank = get_rank(self.group)
+         self.rank_size = get_group_size(self.group)
+         validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name)
+         self.add_prim_attr('rank_size', self.rank_size)
+         self.add_prim_attr('group', self.group)
+         self.add_prim_attr('fusion', 0)
+         self.add_prim_attr('mean_flag', False)
+         self.add_prim_attr('no_eliminate', True)
+
+     def __call__(self, combined):
+         output, _ = all_gather_into_tensor(combined, group=self.group)
          return output

-     if not async_op:
+     def infer_shape(self, x_shape):
+         validator.check_positive_int(len(x_shape), "x shape", self.name)
+         if x_shape[0] > 0:
+             x_shape[0] = x_shape[0] * self.rank_size
+         return x_shape
+
+     def infer_dtype(self, x_dtype):
+         check_collective_target_dtype('x', x_dtype, self.name)
+         return x_dtype
+
+
+ class AllReduce(Primitive):
+     @prim_attr_register
+     def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
+         """Initialize AllReduce."""
+         super().__init__(name="AllReduce")
+         self.group = _get_group(group)
+         if not isinstance(op, type(ReduceOp.SUM)):
+             raise TypeError(f"For '{self.name}', the 'op' must be str, but got {type(op).__name__}.")
+         if not isinstance(self.group, str):
+             raise TypeError(f"For '{self.name}', the 'group' must be str, "
+                             f"but got {type(self.group).__name__}.")
+         self.op = op
+         self.add_prim_attr('group', self.group)
+         self.add_prim_attr('fusion', 0)
+         self.add_prim_attr('index', 0)
+         self.add_prim_attr('no_eliminate', True)
+
+     def __call__(self, combined):
+         output, _ = all_reduce(combined, group=self.group)
          return output
-     return output
-
-
- def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
-     if not isinstance(group, str):
-         raise TypeError("For 'get_group_size', the argument 'group' must be type of string, "
-                         "but got 'group' type : {}.".format(type(group)))
-     return _get_size_helper(group=_get_group(group))
-
-
- def _contiguous(tensor):
-     if not tensor.is_contiguous() or tensor.storage_offset() != 0:
-         tensor = tensor.contiguous()
-     return tensor


- def _get_group(group):
-     """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
-     if group == DEFAULT_WORLD_COMM_GROUP:
-         return GlobalComm.WORLD_COMM_GROUP
-     return group
-
-
- def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
-     if not isinstance(tensor, (Tensor, Tensor_)):
-         raise TypeError(
-             "For all_gather_into_tensor, the input tensor must be tensor")
-     group = _get_group(group)
-     tensor = _contiguous(tensor)
-     all_gather_op = _get_cache_prim(P.AllGather)(group=group)
-     output = all_gather_op(tensor)
-     return _deal_comm_outputs(output, async_op)
-
-
- def all_reduce(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
-     if not isinstance(tensor, (Tensor, Tensor_)):
-         raise TypeError("For all_reduce, the input tensor must be tensor")
-     if not isinstance(op, str):
-         raise TypeError("For all_reduce, the input op type must be str")
-     if op not in ('sum', 'prod', 'min', 'max'):
-         raise TypeError(
-             "For all_reduce, the input op value must be one of sum, prod, min, max")
-     group = _get_group(group)
-     tensor = _contiguous(tensor)
-     all_reduce_op = _get_cache_prim(P.AllReduce)(op=op, group=group)
-     output = all_reduce_op(tensor)
-     return _deal_comm_outputs(output, async_op)
-
-
- def bprop_pynative(input_x, weight, bias, running_mean, running_var, eps, momentum,
-                    process_group, world_size, output, doutput):
-     _, mean_param, invstd_param, count_all_param = output
-     dout, _, _, _ = doutput
-
-     # KBK mode is not supported
-     if not dout.is_contiguous():
-         dout = dout.contiguous()
-
-     grad_input = grad_weight = grad_bias = None
-
-     inputG = True
-     weightG = True
-     biasG = True
-
-     # calculate local stats as well as grad_weight / grad_bias
-     sum_dy, sum_dy_xmu, grad_weight, grad_bias = batch_norm_reduce_grad(
-         dout,
-         input_x,
-         mean_param,
-         invstd_param,
-         weight,
-         inputG,
-         weightG,
-         biasG
-     )
-
-     if inputG:
-         # synchronizing stats used to calculate input gradient.
-         sum_dy_shape = shape(sum_dy)
-         num_channels = sum_dy_shape[0]
-         combined = mint.cat([sum_dy, sum_dy_xmu], dim=0)
-
-         new_combined, _ = mindspore.communication.comm_func.all_reduce(
-             combined, group=process_group)
-
-         sum_dy, sum_dy_xmu = mint.split(new_combined, num_channels)
-
-         # backward pass for gradient calculation
-         grad_input = batch_norm_elemt_grad(
-             dout,
-             input_x,
-             mean_param,
-             invstd_param,
-             weight,
-             sum_dy,
-             sum_dy_xmu,
-             count_all_param
-         )
-
-     # synchronizing of grad_weight / grad_bias is not needed as distributed
-     # training would handle all reduce.
-     if weight is None or not weightG:
-         grad_weight = None
-
-     if weight is None or not biasG:
-         grad_bias = None
-
-     return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+ class SyncBatchNormInner(Cell):
+     def __init__(self, self_num_features, self_world_size):
+         super(SyncBatchNormInner, self).__init__()
+         self.num_features = self_num_features
+         self.world_size = self_world_size

+     def construct(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
+         if self.world_size != world_size:
+             raise ValueError('World Size Error')
+         input = input.contiguous()
+         if weight is not None:
+             weight = weight.contiguous()
+
+         input_shape = shape(input)
+         input_numel = ops.numel(input)
+         size = int(input_numel // input_shape[1])
+         if size == 1 and world_size < 2:
+             raise ValueError(
+                 'Expected more than 1 value per channel when training, got input size {}'.format(size))
+
+         # calculate mean/invstd for input.
+         mean, invstd = mint.batch_norm_stats(input, eps)
+         count = mint.full((1,), input_numel // input_shape[1], dtype=mean.dtype)
+
+         num_channels = input_shape[1]
+         if self.num_features != num_channels:
+             raise ValueError('Features Error')
+         # C, C, 1 -> (2C + 1)
+         combined = mint.cat([mean, invstd, count], dim=0)
+         # Use allgather instead of allreduce because count could be different across
+         # ranks, simple all reduce op can not give correct results.
+         # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
+         # all gathered mean, invstd and count.
+         # world_size * (2C + 1)
+         all_gather_op = AllGather(process_group)
+         combined = all_gather_op(combined)
+         combined = ops.reshape(combined, [world_size, -1])
+         # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
+         mean_val_all, invstd_val_all, count_val_all = mint.split(
+             combined, num_channels, dim=1)
+         # calculate global mean & invstd
+         mean, invstd = mint.batch_norm_gather_stats_with_counts(input, mean_val_all, invstd_val_all, running_mean,
+                                                                 running_var, momentum, eps, count_val_all.view(-1))
+
+         # apply element-wise normalization
+         out = mint.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
+         return (out, mean, invstd, count_val_all.view(-1))

- def bprop_kbk(input_x, weight, bias, running_mean, running_var, eps, momentum,
+     def bprop(self, input_x, weight, bias, running_mean, running_var, eps, momentum,
                process_group, world_size, output, doutput):
-     _, mean_param, invstd_param, count_all_param = output
-     dout, _, _, _ = doutput
+         _, mean_param, invstd_param, count_all_param = output
+         dout, _, _, _ = doutput

-     dout = dout.contiguous()
-
-     grad_input = grad_weight = grad_bias = None
-
-     inputG = True
-     weightG = True
-     biasG = True
-
-     # calculate local stats as well as grad_weight / grad_bias
-     sum_dy, sum_dy_xmu, grad_weight, grad_bias = batch_norm_reduce_grad(
-         dout,
-         input_x,
-         mean_param,
-         invstd_param,
-         weight,
-         inputG,
-         weightG,
-         biasG
-     )
-
-     if inputG:
-         # synchronizing stats used to calculate input gradient.
-         sum_dy_shape = shape(sum_dy)
-         num_channels = sum_dy_shape[0]
-         combined = mint.cat([sum_dy, sum_dy_xmu], dim=0)
+         # KBK mode is not supported
+         dout = dout.contiguous()

-         new_combined = all_reduce(combined, group=process_group)
+         grad_input = grad_weight = grad_bias = None

-         sum_dy, sum_dy_xmu = mint.split(new_combined, num_channels)
+         inputG = True
+         weightG = True
+         biasG = True

-         # backward pass for gradient calculation
-         grad_input = batch_norm_elemt_grad(
+         # calculate local stats as well as grad_weight / grad_bias
+         sum_dy, sum_dy_xmu, grad_weight, grad_bias = batch_norm_reduce_grad(
              dout,
              input_x,
              mean_param,
              invstd_param,
              weight,
-             sum_dy,
-             sum_dy_xmu,
-             count_all_param
+             inputG,
+             weightG,
+             biasG
          )

-     # synchronizing of grad_weight / grad_bias is not needed as distributed
-     # training would handle all reduce.
-     if weight is None or not weightG:
-         grad_weight = None
-
-     if weight is None or not biasG:
-         grad_bias = None
-
-     return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
-
-
- def construct_pynative(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
-                        world_size, self_num_features, self_world_size):
-     if self_world_size != world_size:
-         raise ValueError('World Size Error')
-     if not input.is_contiguous():
-         input = input.contiguous()
-     if weight is not None:
-         weight = weight.contiguous()
-
-     input_shape = shape(input)
-     input_numel = ops.numel(input)
-     size = int(input_numel // input_shape[1])
-     if size == 1 and world_size < 2:
-         raise ValueError(
-             'Expected more than 1 value per channel when training, got input size {}'.format(size))
-
-     # calculate mean/invstd for input.
-     mean, invstd = mint.batch_norm_stats(input, eps)
-     count = mint.full((1,), input_numel //
-                       input_shape[1], dtype=mean.dtype)
-
-     num_channels = input_shape[1]
-     if self_num_features != num_channels:
-         raise ValueError('Features Error')
-     # C, C, 1 -> (2C + 1)
-     combined = mint.cat([mean, invstd, count], dim=0)
-     # Use allgather instead of allreduce because count could be different across
-     # ranks, simple all reduce op can not give correct results.
-     # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
-     # all gathered mean, invstd and count.
-     # world_size * (2C + 1)
-     combined, _ = all_gather_into_tensor_dy(combined, process_group)
-     combined = ops.reshape(combined, [world_size, -1])
-     # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
-     mean_val_all, invstd_val_all, count_val_all = mint.split(
-         combined, num_channels, dim=1)
-     # calculate global mean & invstd
-     mean, invstd = mint.batch_norm_gather_stats_with_counts(input, mean_val_all, invstd_val_all, running_mean,
-                                                             running_var, momentum, eps, count_val_all.view(-1))
-
-     # apply element-wise normalization
-     out = mint.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
-     return (out, mean, invstd, count_val_all.view(-1))
-
-
- def construct_kbk(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
-                   world_size, self_num_features, self_world_size):
-     if self_world_size != world_size:
-         raise ValueError('World Size Error')
-     input = input.contiguous()
-     if weight is not None:
-         weight = weight.contiguous()
-
-     input_shape = shape(input)
-     input_numel = ops.numel(input)
-     size = int(input_numel // input_shape[1])
-     if size == 1 and world_size < 2:
-         raise ValueError(
-             'Expected more than 1 value per channel when training, got input size {}'.format(size))
-
-     # calculate mean/invstd for input.
-     mean, invstd = mint.batch_norm_stats(input, eps)
-     count = mint.full((1,), input_numel //
-                       input_shape[1], dtype=mean.dtype)
-
-     num_channels = input_shape[1]
-     if self_num_features != num_channels:
-         raise ValueError('Features Error')
-     # C, C, 1 -> (2C + 1)
-     combined = mint.cat([mean, invstd, count], dim=0)
-     # Use allgather instead of allreduce because count could be different across
-     # ranks, simple all reduce op can not give correct results.
-     # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
-     # all gathered mean, invstd and count.
-     # world_size * (2C + 1)
-     combined = all_gather_into_tensor(combined, process_group)
-     combined = ops.reshape(combined, [world_size, -1])
-     # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
-     mean_all, invstd_all, count_all = mint.split(
-         combined, num_channels, dim=1)
-     # calculate global mean & invstd
-     mean, invstd = mint.batch_norm_gather_stats_with_counts(
-         input,
-         mean_all,
-         invstd_all,
-         running_mean,
-         running_var,
-         momentum,
-         eps,
-         count_all.view(-1)
-     )
-
-     # apply element-wise normalization
-     out = mint.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
-     return (out, mean, invstd, count_all.view(-1))
-
-
- class SyncBatchNormInner(Cell):
-     def __init__(self, self_num_features, self_world_size):
-         super(SyncBatchNormInner, self).__init__()
-         self.num_features = self_num_features
-         self.world_size = self_world_size
-         self.mode = context.get_context("mode")
-         if self.mode == 1:
-             self.fn_bprop = bprop_pynative
-             self.fn_construct = construct_pynative
-         else:
-             self.fn_bprop = bprop_kbk
-             self.fn_construct = construct_kbk
-
-     def construct(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
-         return self.fn_construct(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
-                                  world_size, self.num_features, self.world_size)
-
-     def bprop(self, input_x, weight, bias, running_mean, running_var, eps, momentum,
-               process_group, world_size, output, doutput):
-         return self.fn_bprop(input_x, weight, bias, running_mean, running_var, eps, momentum,
-                              process_group, world_size, output, doutput)
+         if inputG:
+             # synchronizing stats used to calculate input gradient.
+             sum_dy_shape = shape(sum_dy)
+             num_channels = sum_dy_shape[0]
+             combined = mint.cat([sum_dy, sum_dy_xmu], dim=0)
+             all_reduce_op = AllReduce(group=process_group)
+             new_combined = all_reduce_op(combined)
+
+             sum_dy, sum_dy_xmu = mint.split(new_combined, num_channels)
+
+             # backward pass for gradient calculation
+             grad_input = batch_norm_elemt_grad(
+                 dout,
+                 input_x,
+                 mean_param,
+                 invstd_param,
+                 weight,
+                 sum_dy,
+                 sum_dy_xmu,
+                 count_all_param
+             )
+
+         # synchronizing of grad_weight / grad_bias is not needed as distributed
+         # training would handle all reduce.
+         if weight is None or not weightG:
+             grad_weight = None
+
+         if weight is None or not biasG:
+             grad_bias = None
+
+         return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


  class _SyncBatchNorm(Cell):
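Editor's note on the SyncBatchNorm rewrite above: the forward pass all-gathers the per-rank (mean, invstd, count) triples rather than all-reducing them because, as the diff's own comment says, counts can differ across ranks, so an unweighted reduce of the means would be biased. A count-weighted NumPy sketch of the idea (illustrative only; the real combination is done by mint.batch_norm_gather_stats_with_counts):

    import numpy as np

    # Hypothetical per-rank statistics for one channel, with unequal batch sizes.
    means = np.array([1.0, 3.0])     # per-rank means
    counts = np.array([10.0, 30.0])  # per-rank element counts

    naive = means.mean()                               # 2.0 -- ignores counts, biased
    weighted = (means * counts).sum() / counts.sum()   # 2.5 -- count-weighted, correct
    print(naive, weighted)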
@@ -45,6 +45,10 @@ class SiLU(Cell):
      .. warning::
          This is an experimental API that is subject to change or deletion.

+     Args:
+         inplace (bool, optional): If ``True``, enable the in-place update function.
+             Default value: ``False``.
+
      Inputs:
          - **input** (Tensor) - `input` is :math:`x` in the preceding formula.
            Input with the data type float16 or float32. Tensor of any dimension.
@@ -63,18 +67,19 @@ class SiLU(Cell):
      >>> from mindspore import Tensor, mint
      >>> import numpy as np
      >>> input = Tensor(np.array([-1, 2, -3, 2, -1]), mindspore.float16)
-     >>> silu = mint.nn.SiLU()
+     >>> silu = mint.nn.SiLU(inplace=False)
      >>> output = silu(input)
      >>> print(output)
      [-0.269 1.762 -0.1423 1.762 -0.269]
      """

-     def __init__(self):
+     def __init__(self, inplace=False):
          """Initialize SiLU."""
          super(SiLU, self).__init__()
+         self.inplace = inplace

      def construct(self, x):
-         return mint.nn.functional.silu(x)
+         return mint.nn.functional.silu(x, self.inplace)


  class Sigmoid(Cell):
@@ -355,9 +360,6 @@ class Threshold(Cell):
          \text{value}, &\text{ otherwise }
      \end{cases}

- .. warning::
-     This is an experimental API that is subject to change or deletion.
-
  Args:
      threshold (Union[int, float]): The value of the threshold.
      value (Union[int, float]): The value to replace with when element is less than threshold.