mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.6.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +47 -198
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +229 -99
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +480 -372
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +5 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +975 -1981
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +324 -573
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +183 -117
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +179 -120
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +798 -761
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +933 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1373 -192
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +53 -42
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +19 -15
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +3 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +361 -359
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +52 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +757 -185
  287. mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
  288. mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4351 -3813
  334. mindspore/ops/function/nn_func.py +1712 -637
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +452 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +26 -18
  343. mindspore/ops/functional.py +23 -7
  344. mindspore/ops/functional_overload.py +1548 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +23 -15
  347. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +298 -87
  360. mindspore/ops/operations/debug_ops.py +157 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +212 -531
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1895 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +159 -40
  426. mindspore/parallel/_cell_wrapper.py +132 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +700 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +258 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -59
  446. mindspore/parallel/transform_safetensors.py +364 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +416 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +96 -27
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +269 -136
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +552 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
@@ -29,7 +29,7 @@ from mindspore.common.api import jit
29
29
  from mindspore.common.tensor import Tensor
30
30
  from mindspore.common._register_for_tensor import Registry
31
31
  from mindspore._c_expression import MetaFuncGraph_, function_id
32
- from mindspore._c_expression import Tensor as Tensor_
32
+ from mindspore._c_expression import TensorPy as Tensor_
33
33
  from mindspore._extends.parse.resources import convert_object_map
34
34
  from mindspore import _checkparam as validator
35
35
  from mindspore import Parameter, ParameterTuple
@@ -49,11 +49,12 @@ from mindspore.train.data_sink import _init_sink_dataset
49
49
  from mindspore.train.summary import SummaryRecord
50
50
  from mindspore.train._utils import _exec_datagraph
51
51
  from mindspore.train.summary.writer import BaseWriter
52
- from mindspore.train.serialization import _exec_save, load, export_split_mindir, obfuscate_model, _parse_ckpt_proto, \
52
+ from mindspore.train.serialization import _exec_save, load, export_split_mindir, _parse_ckpt_proto, \
53
53
  _generate_front_info_for_param_data_file, _get_data_file, _encrypt_data, _split_save, _save_mindir_together, \
54
54
  _load_into_param_dict
55
55
  from mindspore.parallel import _cost_model_context
56
56
  from mindspore.parallel._offload_context import offload_context
57
+ from mindspore.parallel._utils import _is_in_data_parallel_mode
57
58
  from mindspore.run_check._check_version import check_version_and_env_config
58
59
  from mindspore.dataset.callback.ds_callback import DSCallback, WaitedDSCallback
59
60
  from mindspore.dataset.transforms.c_transforms import TensorOperation as CTensorOperation, OneHot as COneHot, \
@@ -127,7 +128,7 @@ from mindspore.dataset.vision.transforms import AdjustBrightness, AdjustContrast
127
128
  RandomVerticalFlipWithBBox as VRandomVerticalFlipWithBBox, Rescale as VRescale, Resize as VResize, ResizedCrop, \
128
129
  ResizeWithBBox as VResizeWithBBox, Rotate as VRotate, SlicePatches as VSlicePatches, Solarize, ToTensor,\
129
130
  TrivialAugmentWide, UniformAugment as VUniformAugment, VerticalFlip as VVerticalFlip
130
- from mindspore.profiler.profiling import Profiler
131
+ from mindspore.profiler.profiler import Profiler
131
132
  from mindspore.communication._hccl_management import get_rank_size, get_rank_id
132
133
  from mindspore.communication._comm_helper import _create_group_helper, _destroy_group_helper
133
134
  from mindspore.communication.management import _set_rank_from_mpi, init as cinit, release as crelease
@@ -360,6 +361,7 @@ FUNC_KEY_DICT_ITEMS = 22 # dict.items
360
361
  FUNC_KEY_PRIMITIVE_ASSIGN = 23 # mindspore.ops.assign, Primitive("Assign")
361
362
  FUNC_KEY_TENSOR_SETITEM = 24 # Tensor.__setitem__
362
363
  FUNC_KEY_TENSOR_ASSIGN_VALUE = 25 # Tensor.assign_value
364
+ FUNC_KEY_TENSOR_IS_CONTIGUOUS = 26 # Tensor.is_contiguous
363
365
 
364
366
  # Initialized only once. This map will initialize by c++ when start pijit.
365
367
  # key is customer if fuzzy match. (Primitive, constexpr, primexpr, MetaFuncGraph)
@@ -376,19 +378,19 @@ _func_map = {
376
378
  constexpr_key: FUNC_KEY_CONSTEXPR,
377
379
  primexpr_key: FUNC_KEY_PRIMEXPR,
378
380
  meta_func_graph_key: FUNC_KEY_META_FUNCG_RAPH,
379
- id(GraphCell.__call__): FUNC_KEY_GRAPH_CELL,
381
+ function_id(GraphCell.__call__): FUNC_KEY_GRAPH_CELL,
380
382
  id(psjit_code): FUNC_KEY_PSJIT_CODE,
381
- id(_get_cache_prim): FUNC_KEY_GET_CACHE_PRIM,
382
- id(Registry.get): FUNC_KEY_REGISTRY_GET,
383
+ function_id(_get_cache_prim): FUNC_KEY_GET_CACHE_PRIM,
384
+ function_id(Registry.get): FUNC_KEY_REGISTRY_GET,
383
385
 
384
386
  # tensor side-effect
385
387
  primitive_assign_key: FUNC_KEY_PRIMITIVE_ASSIGN,
386
- id(F.assign): FUNC_KEY_PRIMITIVE_ASSIGN,
387
- id(Tensor.assign_value): FUNC_KEY_TENSOR_ASSIGN_VALUE,
388
- id(Tensor.__setitem__): FUNC_KEY_TENSOR_SETITEM,
388
+ function_id(F.assign): FUNC_KEY_PRIMITIVE_ASSIGN,
389
+ function_id(Tensor.assign_value): FUNC_KEY_TENSOR_ASSIGN_VALUE,
390
+ function_id(Tensor.__setitem__): FUNC_KEY_TENSOR_SETITEM,
389
391
 
390
392
  # Tensor method
391
- id(Tensor.astype): FUNC_KEY_TENSOR_ASTYPE,
393
+ function_id(Tensor.astype): FUNC_KEY_TENSOR_ASTYPE,
392
394
 
393
395
  # types.BuiltinFunctionType
394
396
  function_id(isinstance): FUNC_KEY_BUILTIN_FUNC,
@@ -448,6 +450,7 @@ _func_map = {
448
450
  function_id(str.isalnum): FUNC_KEY_BUILTIN_FUNC,
449
451
  function_id(str.isidentifier): FUNC_KEY_BUILTIN_FUNC,
450
452
  function_id(str.isprintable): FUNC_KEY_BUILTIN_FUNC,
453
+ function_id(str.replace): FUNC_KEY_BUILTIN_FUNC,
451
454
  function_id(str.format): FUNC_KEY_BUILTIN_FUNC,
452
455
  function_id(str.format_map): FUNC_KEY_BUILTIN_FUNC,
453
456
  function_id(str.__format__): FUNC_KEY_BUILTIN_FUNC,
@@ -472,7 +475,7 @@ _func_map = {
472
475
  function_id(Tensor_.getitem_index_info): FUNC_KEY_BUILTIN_FUNC,
473
476
  function_id(Tensor_.get_bytes): FUNC_KEY_BUILTIN_FUNC,
474
477
  function_id(Tensor_.is_init): FUNC_KEY_BUILTIN_FUNC,
475
- function_id(Tensor_.is_contiguous): FUNC_KEY_BUILTIN_FUNC,
478
+ function_id(Tensor_.is_contiguous): FUNC_KEY_TENSOR_IS_CONTIGUOUS,
476
479
  function_id(Tensor_.stride): FUNC_KEY_BUILTIN_FUNC,
477
480
  # Tensor_.asnumpy need real tensor value
478
481
 
@@ -488,6 +491,7 @@ _func_map = {
488
491
  function_id(validator.check_number_range): FUNC_KEY_PIJIT_CONSTEXPR,
489
492
  function_id(validator.check_is_int): FUNC_KEY_PIJIT_CONSTEXPR,
490
493
  function_id(validator.check_is_number): FUNC_KEY_PIJIT_CONSTEXPR,
494
+ function_id(validator.check_positive_int_sequence): FUNC_KEY_PIJIT_CONSTEXPR,
491
495
  function_id(np_version_valid): FUNC_KEY_PIJIT_CONSTEXPR,
492
496
  function_id(_is_initialized): FUNC_KEY_PIJIT_CONSTEXPR,
493
497
  function_id(_set_elegant_exit_handle): FUNC_KEY_PIJIT_CONSTEXPR,
@@ -496,7 +500,9 @@ _func_map = {
496
500
  function_id(get_rank_size): FUNC_KEY_PIJIT_CONSTEXPR,
497
501
  function_id(get_rank_id): FUNC_KEY_PIJIT_CONSTEXPR,
498
502
  function_id(offload_context): FUNC_KEY_PIJIT_CONSTEXPR,
503
+ function_id(_is_in_data_parallel_mode): FUNC_KEY_PIJIT_CONSTEXPR,
499
504
  function_id(check_version_and_env_config): FUNC_KEY_PIJIT_CONSTEXPR,
505
+ function_id(Tensor.tolist): FUNC_KEY_PIJIT_CONSTEXPR,
500
506
 
501
507
  # inner function
502
508
  function_id(type_size_in_bytes): FUNC_KEY_BUILTIN_FUNC,
@@ -530,7 +536,6 @@ _func_map = {
530
536
  function_id(_exec_save): FUNC_KEY_PIJIT_FORBIDDEN,
531
537
  function_id(load): FUNC_KEY_PIJIT_FORBIDDEN,
532
538
  function_id(export_split_mindir): FUNC_KEY_PIJIT_FORBIDDEN,
533
- function_id(obfuscate_model): FUNC_KEY_PIJIT_FORBIDDEN,
534
539
  function_id(_parse_ckpt_proto): FUNC_KEY_PIJIT_FORBIDDEN,
535
540
  function_id(_generate_front_info_for_param_data_file): FUNC_KEY_PIJIT_FORBIDDEN,
536
541
  function_id(_get_data_file): FUNC_KEY_PIJIT_FORBIDDEN,
@@ -0,0 +1,27 @@
1
+ # Copyright 2025 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """Store and get tensor method"""
16
+ from mindspore import Tensor
17
+ from mindspore._c_expression import function_id
18
+
19
+ tensor_method_id_to_name = {}
20
+ for method_name in dir(Tensor):
21
+ method_id = function_id(getattr(Tensor, method_name))
22
+ tensor_method_id_to_name[method_id] = method_name
23
+
24
+
25
+ def get_tensor_method_name(id):
26
+ """Get method name by function id"""
27
+ return tensor_method_id_to_name.get(id, None)
@@ -28,7 +28,7 @@ def cell_attr_register(fn=None, attrs=None):
28
28
 
29
29
  Args:
30
30
  fn (function): __init__ function of cell.
31
- attrs (list(string) | string): attr list.
31
+ attrs (list(str) | str): attr list.
32
32
 
33
33
  Returns:
34
34
  function, original function.
mindspore/amp.py CHANGED
@@ -69,6 +69,12 @@ def _enable_all_finite():
69
69
  if not checker.check_custom_version():
70
70
  logger.debug("Disable AllFinite due to version check failure.")
71
71
  return False
72
+ else:
73
+ return False
74
+
75
+ if "RANK_TABLE_FILE" in os.environ:
76
+ return False
77
+
72
78
  runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF')
73
79
  global_jit_config = context.get_jit_config()
74
80
  if runtime_conf is not None and ("all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf):
@@ -82,7 +88,7 @@ def _enable_all_finite():
82
88
  if global_jit_config:
83
89
  logger.debug("Current global jit config is: {}".format(global_jit_config["jit_level"]))
84
90
  return global_jit_config["jit_level"] == "O0" or global_jit_config["jit_level"] == "O1"
85
- return False
91
+ return True
86
92
 
87
93
 
88
94
  def _grad_unscale(scale, grad):
@@ -93,12 +99,12 @@ def _grad_scale(scale, grad):
93
99
  return grad * scale.astype(grad.dtype)
94
100
 
95
101
 
96
- @jit
102
+ @jit(backend="ms_backend")
97
103
  def _grad_scale_map(scale_value, inputs):
98
104
  return _hypermap(_partial(_grad_scale, scale_value), inputs)
99
105
 
100
106
 
101
- @jit
107
+ @jit(backend="ms_backend")
102
108
  def _grad_unscale_map(scale_value, inputs):
103
109
  return _hypermap(_partial(_grad_unscale, scale_value), inputs)
104
110
 
@@ -110,7 +116,7 @@ def _overflow(inputs):
110
116
  return 1 - status.all()
111
117
 
112
118
 
113
- @jit
119
+ @jit(backend="ms_backend")
114
120
  def _all_finite(inputs, check_overflow_mode, enable_allfinite):
115
121
  """all finite check"""
116
122
  if _ascend_target():
@@ -319,7 +325,7 @@ class StaticLossScaler(LossScaler):
319
325
 
320
326
  class DynamicLossScaler(LossScaler):
321
327
  r"""
322
- Dynamic Loss scale class.
328
+ Manager for dynamically adjusting the loss scaling factor.
323
329
 
324
330
  Dynamic loss scaling tries to determine the largest loss scale value that
325
331
  will keep gradients finite. It does this by increasing the loss scale every
mindspore/atlprov.dll CHANGED
Binary file
mindspore/avcodec-59.dll CHANGED
Binary file
mindspore/avdevice-59.dll CHANGED
Binary file
mindspore/avfilter-8.dll CHANGED
Binary file
mindspore/avformat-59.dll CHANGED
Binary file
mindspore/avutil-57.dll CHANGED
Binary file
@@ -13,8 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """
16
- Boost provide auto accelerating for network, such as Less BN, Gradient Freeze, Gradient
17
- accumulation and so on.
16
+ Boost is able to automatically optimize network performance, e.g., by reducing BN, gradient freezing,
17
+ and accumulating gradients to achieve network acceleration.
18
18
 
19
19
  Note:
20
20
  This feature is a beta feature, and we are still improving its functionality.
mindspore/boost/base.py CHANGED
@@ -21,15 +21,12 @@ import math
21
21
  import copy
22
22
  import numpy as np
23
23
  from scipy import linalg as la
24
- from mindspore.context import ParallelMode
25
24
  import mindspore.nn as nn
26
25
  from mindspore.nn.optim import LARS
27
26
  from mindspore import log as logger
28
27
  from mindspore.common import Parameter
29
- from mindspore.communication.management import get_group_size
28
+ from mindspore.communication.management import get_rank, get_group_size
30
29
  from mindspore.train.serialization import load_checkpoint
31
- from mindspore.parallel._utils import _get_global_rank
32
- from mindspore.parallel._auto_parallel_context import auto_parallel_context
33
30
  from mindspore.boost.less_batch_normalization import CommonHeadLastFN
34
31
 
35
32
 
@@ -329,7 +326,7 @@ def _get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_n
329
326
  if os.path.exists(save_pca_end_path):
330
327
  os.remove(save_pca_end_path)
331
328
 
332
- rank = _get_global_rank()
329
+ rank = get_rank()
333
330
  local_pca_mat_path = full_pca_mat_path[:-4] + "_rank_" + str(rank) + ".npy"
334
331
  if os.path.exists(local_pca_mat_path):
335
332
  os.remove(local_pca_mat_path)
@@ -498,8 +495,7 @@ def _save_local_pca_mat(pca_mat, full_pca_mat_path, n_component):
498
495
  full_pca_mat_path (str): the path of full pca mat.
499
496
  n_component (int): pca component.
500
497
  """
501
- parallel_mode = auto_parallel_context().get_parallel_mode()
502
- rank_size = 1 if parallel_mode == ParallelMode.STAND_ALONE else get_group_size()
498
+ rank_size = get_group_size()
503
499
  local_dim = math.ceil(n_component // rank_size)
504
500
  for rank_id in range(rank_size):
505
501
  start_index = rank_id * local_dim
@@ -1,4 +1,4 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
1
+ # Copyright 2021-2025 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,12 +15,13 @@
15
15
  """Boost Mode Cell Wrapper."""
16
16
  from __future__ import absolute_import
17
17
 
18
+ import os
18
19
  import numpy as np
19
20
  from mindspore.nn.wrap import TrainOneStepCell
20
21
  import mindspore.context as context
21
22
  from mindspore.context import ParallelMode
22
23
  from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_gradients_mean
23
- from mindspore.communication.management import get_group_size, create_group
24
+ from mindspore.communication.management import get_rank, get_group_size, create_group
24
25
  from mindspore.nn.cell import Cell
25
26
  from mindspore.nn import SequentialCell
26
27
  from mindspore.common import Tensor
@@ -38,6 +39,10 @@ from mindspore.boost.adasum import AdaSum
38
39
  from mindspore.boost.dim_reduce import DimReduce
39
40
  from mindspore.boost.grad_accumulation import gradient_accumulation_op, gradient_clear_op
40
41
  from mindspore.boost.base import _load_local_pca_mat
42
+ from mindspore.ops.operations.nn_ops import AllFinite
43
+ from mindspore._c_expression import MSContext
44
+ from mindspore.run_check._check_version import AscendEnvChecker
45
+ from mindspore import log as logger
41
46
 
42
47
  __all__ = ["BoostTrainOneStepCell", "BoostTrainOneStepWithLossScaleCell"]
43
48
 
@@ -90,6 +95,27 @@ def _tensor_grad_overflow(grad):
90
95
  def _tensor_grad_overflow_row_tensor(grad):
91
96
  return grad_overflow(grad.values)
92
97
 
98
+ _ascend_grad_overflow = C.MultitypeFuncGraph("_ascend_grad_overflow")
99
+ ascend_grad_overflow = P.IsFinite()
100
+
101
+
102
+ @_ascend_grad_overflow.register("Tensor")
103
+ def _tensor_ascend_grad_overflow(grad):
104
+ status = ascend_grad_overflow(grad)
105
+ base = Tensor(1.0, dtype=mstype.float32)
106
+ output = base - status.all()
107
+ output = P.Reshape()(output, ((-1,)))
108
+ return output
109
+
110
+
111
+ @_ascend_grad_overflow.register("RowTensor")
112
+ def _tensor_ascend_grad_overflow_row_tensor(grad):
113
+ status = ascend_grad_overflow(grad.values)
114
+ base = Tensor(1.0, dtype=mstype.float32)
115
+ output = base - status.all()
116
+ output = P.Reshape()(output, ((1,)))
117
+ return output
118
+
93
119
 
94
120
  class _OutputToFloat16(Cell):
95
121
  "Wrap cell for amp. Cast network output back to float16"
@@ -362,7 +388,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
362
388
  gamma = self.auto_boost.gamma
363
389
  alpha = self.auto_boost.alpha
364
390
  sigma = self.auto_boost.sigma
365
- _rank = _get_global_rank()
391
+ _rank = get_rank()
366
392
  _rank_size = 1 if self.parallel_mode == ParallelMode.STAND_ALONE else get_group_size()
367
393
  n_components = self.auto_boost.n_components
368
394
  timeout = self.auto_boost.timeout
@@ -483,7 +509,11 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
483
509
  self.allreduce = P.AllReduce()
484
510
  self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
485
511
  self.gpu_target = (context.get_context("device_target") == "GPU")
512
+ self.ascend_910a_target = (MSContext.get_instance().get_ascend_soc_version() == 'ascend910')
513
+ self.ascend_910b_target = (MSContext.get_instance().get_ascend_soc_version() in ['ascend910b', 'ascend910_93'])
486
514
  self.loss_scaling_manager = None
515
+ self._ascend_check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE')
516
+
487
517
  self.base0 = Tensor(0, mstype.int32)
488
518
  self.reduce_all = P.ReduceAll(keep_dims=False)
489
519
  self.logic_not = P.LogicalNot()
@@ -512,6 +542,26 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
512
542
  else:
513
543
  raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
514
544
 
545
+ self.enable_allfinite = True
546
+ runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF')
547
+ global_jit_config = context.get_jit_config()
548
+ if runtime_conf is not None and ("all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf):
549
+ logger.debug("Enable AllFinite through the environment variable MS_DEV_RUNTIME_CONF.")
550
+ self.enable_allfinite = True
551
+ elif runtime_conf is not None and ("all_finite:False" in runtime_conf or "all_finite:false" in runtime_conf):
552
+ logger.debug("Disable AllFinite through the environment variable MS_DEV_RUNTIME_CONF.")
553
+ self.enable_allfinite = False
554
+ elif global_jit_config:
555
+ logger.debug("Current global jit config is: {}".format(global_jit_config["jit_level"]))
556
+ self.enable_allfinite = global_jit_config["jit_level"] == "O0" or global_jit_config["jit_level"] == "O1"
557
+ if "RANK_TABLE_FILE" in os.environ:
558
+ self.enable_allfinite = False
559
+ if self.ascend_910b_target:
560
+ checker = AscendEnvChecker(None)
561
+ if not checker.check_custom_version():
562
+ logger.debug("Disable AllFinite due to version check failure.")
563
+ self.enable_allfinite = False
564
+
515
565
  def construct(self, *inputs):
516
566
  weights = self.weights
517
567
  loss = self.network(*inputs)
@@ -523,7 +573,7 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
523
573
  cond, scaling_sens = self._enhanced_amp_process_overflow_status(grads)
524
574
  else:
525
575
  scaling_sens = self.scale_sense
526
- status, scaling_sens = self._start_overflow_check(loss, scaling_sens)
576
+ status = Tensor([0] * 8, mstype.int32)
527
577
  scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
528
578
 
529
579
  grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
@@ -646,54 +696,99 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
646
696
  compute_input = F.depend(compute_input, clear_status)
647
697
  return status, compute_input
648
698
 
699
+ def _check_overflow_status_on_infnan_mode(self, grad_overflow_check_func, compute_output):
700
+ """check overflow status on infnan mode."""
701
+ flag_sum = self.hyper_map(F.partial(grad_overflow_check_func), compute_output)
702
+ flag_sum = P.AddN()(flag_sum)
703
+ # convert flag_sum to scalar
704
+ flag_sum = P.Reshape()(flag_sum, (()))
705
+ return flag_sum
706
+
707
+ def _get_distributed_overflow_status_on_infnan_mode(self, grad_overflow_check_func, compute_output):
708
+ """converge the distributed overflow status on infnan mode."""
709
+ flag_sum = self._check_overflow_status_on_infnan_mode(grad_overflow_check_func, compute_output)
710
+
711
+ if self.is_distributed:
712
+ # sum overflow flag over devices
713
+ flag_reduce = self.allreduce(flag_sum)
714
+ overflow = self.less_equal(self.base, flag_reduce)
715
+ else:
716
+ overflow = self.less_equal(self.base, flag_sum)
717
+ return overflow
718
+
719
+ def _get_distributed_overflow_status_on_infnan_enable_allfinite(self, compute_output):
720
+ """check overflow status on infnan kernel mode."""
721
+ overflow = AllFinite()(compute_output)
722
+
723
+ if self.is_distributed:
724
+ overflow = P.Cast()(overflow, mstype.int8)
725
+ overflow = P.Cast()(self.allreduce(overflow), mstype.bool_)
726
+ return overflow
727
+
728
+ def _get_gpu_overflow_status(self, compute_output):
729
+ """get overflow status of gpu."""
730
+ overflow = self._get_distributed_overflow_status_on_infnan_mode(_grad_overflow, compute_output)
731
+ return overflow
732
+
733
+ def _get_ascend_overflow_status_on_infnan_mode(self, compute_output):
734
+ """get overflow status of ascend on infnan mode."""
735
+ overflow = False
736
+ if self.enable_allfinite:
737
+ overflow = self._get_distributed_overflow_status_on_infnan_enable_allfinite(compute_output)
738
+ else:
739
+ overflow = self._get_distributed_overflow_status_on_infnan_mode(_ascend_grad_overflow, compute_output)
740
+ return overflow
741
+
742
+ def _get_ascend_overflow_status_on_saturation_mode(self, status, compute_output):
743
+ """get overflow status of ascend on saturation mode"""
744
+ status = F.depend(status, compute_output)
745
+ get_status = NPUGetFloatStatusV2()(status)
746
+
747
+ if self.is_distributed:
748
+ # sum overflow flag over devices
749
+ flag_reduce = self.allreduce(get_status)
750
+ # get_status not equal to [0]*8 means overflow
751
+ flag = self.equal(self.base0, flag_reduce)
752
+ status = F.depend(status, flag)
753
+ # distributed needs to skip allreduce to avoid its overflow affecting the next step
754
+ clear_status = NPUClearFloatStatusV2()(status)
755
+ flag = F.depend(flag, clear_status)
756
+ overall_finite = self.reduce_all(flag)
757
+ else:
758
+ status = F.depend(status, get_status)
759
+ clear_status = NPUClearFloatStatusV2()(status)
760
+ get_status = F.depend(get_status, clear_status)
761
+ flag = self.equal(self.base0, get_status)
762
+ overall_finite = self.reduce_all(flag)
763
+ overflow = self.logic_not(overall_finite)
764
+ return overflow
765
+
766
+
649
767
  def _get_overflow_status(self, status, compute_output):
650
768
  """
651
769
  Get floating-point overflow status.
652
770
 
653
- Get overflow results after executing the target process for overflow detection.
771
+ Get overflow results after executing the target process for overflow detection. User-defined training network
772
+ based on this class can also call this interface to process the overflow.
654
773
 
655
- Inputs:
656
- - **status** (object) - A status instance used to detect the overflow.
657
- - **compute_output** - Overflow detection should be performed on a certain computation. Set `compute_output`
658
- as the output of the computation, to ensure overflow status is acquired before executing the
659
- computation.
774
+ Args:
775
+ status (object): To control the execution sequence with start_overflow_check, it should be set as the first
776
+ output of start_overflow_check.
777
+ compute_output: Overflow detection should be performed in a certain computation process. Set
778
+ `compute_output` as the output of the computation process.
660
779
 
661
- Outputs:
780
+ Returns:
662
781
  bool, whether the overflow occurs or not.
663
782
  """
664
- if not self.gpu_target:
665
- status = F.depend(status, compute_output)
666
- get_status = NPUGetFloatStatusV2()(status)
667
-
668
- if self.is_distributed:
669
- # sum overflow flag over devices
670
- flag_reduce = self.allreduce(get_status)
671
- # get_status not equal to [0]*8 means overflow
672
- flag = self.equal(self.base0, flag_reduce)
673
- status = F.depend(status, flag)
674
- # distributed needs to skip allreduce to avoid its overflow affecting the next step
675
- clear_status = NPUClearFloatStatusV2()(status)
676
- flag = F.depend(flag, clear_status)
677
- overall_finite = self.reduce_all(flag)
678
- else:
679
- status = F.depend(status, get_status)
680
- clear_status = NPUClearFloatStatusV2()(status)
681
- get_status = F.depend(get_status, clear_status)
682
- flag = self.equal(self.base0, get_status)
683
- overall_finite = self.reduce_all(flag)
684
- overflow = self.logic_not(overall_finite)
685
- else:
686
- flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
687
- flag_sum = P.AddN()(flag_sum)
688
- # convert flag_sum to scalar
689
- flag_sum = P.Reshape()(flag_sum, (()))
690
-
691
- if self.is_distributed:
692
- # sum overflow flag over devices
693
- flag_reduce = self.allreduce(flag_sum)
694
- overflow = self.less_equal(self.base, flag_reduce)
783
+ if self.gpu_target:
784
+ overflow = self._get_gpu_overflow_status(compute_output)
785
+ elif self.ascend_910b_target:
786
+ if self._ascend_check_overflow_mode == "SATURATION_MODE":
787
+ overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output)
695
788
  else:
696
- overflow = self.less_equal(self.base, flag_sum)
789
+ overflow = self._get_ascend_overflow_status_on_infnan_mode(compute_output)
790
+ else: # ascend_910a_target
791
+ overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output)
697
792
  return overflow
698
793
 
699
794
  def _process_loss_scale(self, overflow):
mindspore/c1.dll CHANGED
Binary file
mindspore/c1xx.dll CHANGED
Binary file
mindspore/c2.dll CHANGED
Binary file
@@ -15,7 +15,8 @@
15
15
  """Top-level reference to dtype of common module."""
16
16
  from __future__ import absolute_import
17
17
  from mindspore.common import dtype
18
- from mindspore.common.api import ms_function, ms_memory_recycle, ms_class, jit, jit_class, _no_grad, flops_collection
18
+ from mindspore.common.api import ms_memory_recycle, jit, jit_class, _no_grad, \
19
+ flops_collection, set_recursion_limit
19
20
  from mindspore.common.dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
20
21
  uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
21
22
  float32, single, float64, bfloat16, double, bool_, float_, list_, tuple_, int_, \
@@ -38,6 +39,7 @@ from mindspore.common import generator
38
39
  from mindspore.common.generator import (
39
40
  Generator, default_generator, seed, manual_seed, initial_seed, get_rng_state, set_rng_state)
40
41
  from mindspore.ops.function.array_func import is_tensor, from_numpy
42
+ from mindspore.common._grad_function import _Function
41
43
 
42
44
  # symbols from dtype
43
45
  __all__ = [
@@ -69,18 +71,19 @@ __all__ = [
69
71
 
70
72
  __all__.extend([
71
73
  "tensor", "Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
72
- "ms_function", "ms_class", 'jit', 'jit_class', '_no_grad', # api
74
+ 'jit', 'jit_class', '_no_grad', # api
73
75
  "Parameter", "ParameterTuple", # parameter
74
76
  "dtype",
75
77
  "set_seed", "get_seed", "manual_seed", # random seed
76
78
  "set_dump",
77
79
  "ms_memory_recycle",
80
+ "set_recursion_limit",
78
81
  "mutable", "JitConfig",
79
82
  "flops_collection",
80
83
  "lazy_inline", "load_mindir", "save_mindir",
81
84
  "no_inline",
82
85
  "Symbol",
83
86
  "recompute",
84
- "is_tensor", "from_numpy",
87
+ "is_tensor", "from_numpy", "_Function"
85
88
  ])
86
89
  __all__.extend(generator.__all__)
@@ -0,0 +1,56 @@
1
+ # Copyright 2025 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """Defines custom autograd function with functional form."""
17
+
18
+ from typing import Any
19
+ from mindspore._c_expression import FunctionBase as FunctionBase_
20
+ from mindspore.common.tensor import Tensor
21
+
22
+ __all__ = ['_Function']
23
+
24
+ class _Function(FunctionBase_):
25
+ """
26
+ A Class provides the ability to custom autograd function.
27
+
28
+ Note:
29
+ It is only supported in pynative mode.
30
+
31
+ Supported Platforms:
32
+ ``Ascend`` ``GPU`` ``CPU``
33
+ """
34
+ @staticmethod
35
+ def forward(ctx: Any, *args: Any, **kwars: Any) -> Any:
36
+ raise NotImplementedError("forward function should be customized.")
37
+
38
+ @staticmethod
39
+ def backward(ctx: Any, *grad_outputs: Any) -> Any:
40
+ raise NotImplementedError("backward function should be customized.")
41
+
42
+ @classmethod
43
+ def apply(cls, *args, **kwargs):
44
+ return super().apply(cls, *args, **kwargs)
45
+
46
+ def save_for_backward(self, *tensors: Tensor):
47
+ self.saved_tensors = tensors
48
+
49
+ def mark_dirty(self, *args: Tensor):
50
+ self.dirty_tensors = args
51
+
52
+ def mark_non_differentiable(self, *args: Tensor):
53
+ self.non_differentiable = args
54
+
55
+ def set_materialize_grads(self, value: bool):
56
+ self.materialize_grads = value