mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.6.0rc1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -22,12 +22,13 @@ from mindspore.common import Tensor
22
22
  from mindspore.common import dtype as mstype
23
23
  from mindspore.nn.cell import Cell
24
24
  from mindspore.nn.grad.cell_grad import _LinearizeInner
25
- from mindspore.ops.operations.other_ops import stop_gradient_
25
+ from mindspore.ops.operations.other_ops import stop_gradient_op
26
26
  from mindspore.ops.primitive import constexpr, _primexpr
27
27
  from mindspore.ops.function.array_func import ones, expand_dims, size, reshape, broadcast_to, transpose, zeros
28
28
  from mindspore.ops.composite import _Vmap, _Grad, _TaylorOperation, GradOperation
29
29
  from mindspore.ops import operations as P
30
30
  from mindspore.ops.operations import _inner_ops as inner
31
+ from mindspore.ops.auto_generate.gen_ops_prim import inplace_stop_gradient_op
31
32
 
32
33
  cast = P.Cast()
33
34
  dtype = P.DType()
@@ -103,11 +104,13 @@ def grad(fn, grad_position=0, weights=None, has_aux=False, return_ids=False):
103
104
 
104
105
  Args:
105
106
  fn (Union[Cell, Function]): Function to do GradOperation.
106
- grad_position (Union[NoneType, int, tuple[int]]): Index to specify which inputs to be differentiated.
107
- If int, get the gradient with respect to single input.
108
- If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
109
- If None, none derivative of any input will be figured out, and in this case, `weights` is required.
110
- Default: ``0`` .
107
+ grad_position (Union[NoneType, int, tuple[int]]):
108
+ Index to specify which inputs to be differentiated. Default: ``0`` .
109
+
110
+ - If int, get the gradient with respect to single input.
111
+ - If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
112
+ - If None, none derivative of any input will be figured out, and in this case, `weights` is required.
113
+
111
114
  weights (Union[ParameterTuple, Parameter, list[Parameter]]): The parameters of the training network that need to
112
115
  calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
113
116
  Default: ``None`` .
@@ -237,36 +240,43 @@ def value_and_grad(fn, grad_position=0, weights=None, has_aux=False, return_ids=
237
240
 
238
241
  As for gradient, three typical cases are included:
239
242
 
240
- 1. gradient with respect to inputs. In this case, `grad_position` is not None while `weights` is None.
241
- 2. gradient with respect to weights. In this case, `grad_position` is None while `weights` is not None.
242
- 3. gradient with respect to inputs and weights. In this case, `grad_position` and `weights` are not None.
243
+ 1. gradient with respect to inputs. In this case, `grad_position` is not None while `weights` is ``None``.
244
+ 2. gradient with respect to weights. In this case, `grad_position` is None while `weights` is not ``None``.
245
+ 3. gradient with respect to inputs and weights. In this case, `grad_position` and `weights` are not ``None``.
243
246
 
244
247
  Args:
245
248
  fn (Union[Cell, Function]): Function to do GradOperation.
246
- grad_position (Union[NoneType, int, tuple[int]]): Index to specify which inputs to be differentiated.
247
- If int, get the gradient with respect to single input.
248
- If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
249
- If None, none derivative of any input will be solved, and in this case, `weights` is required.
250
- Default: ``0`` .
251
- weights (Union[ParameterTuple, Parameter, list[Parameter]]): The parameters of the training network that need to
249
+ grad_position (Union[NoneType, int, tuple[int]], optional): Index to specify which inputs
250
+ to be differentiated. Default: ``0`` .
251
+
252
+ - If int, get the gradient with respect to single input.
253
+ - If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
254
+ - If None, none derivative of any input will be solved, and in this case, `weights` is required.
255
+
256
+ weights (Union[ParameterTuple, Parameter, list[Parameter]], optional):
257
+ The parameters of the training network that need to
252
258
  calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
253
259
  Default: ``None`` .
254
- has_aux (bool): If ``True`` , only the first output of `fn` contributes the gradient of `fn`, while the other
260
+ has_aux (bool, optional): If ``True`` , only the first output of `fn` contributes the gradient of `fn`,
261
+ while the other
255
262
  outputs will be returned directly. It means that `fn` must return more than one output in this case.
256
263
  Default: ``False`` .
257
- return_ids(bool): Whether return the tuple made by gradients and the index to specify which inputs
258
- to be differentiated or the name of parameters of the training network that need to calculate the gradient.
259
- If ``True`` , the output gradients will be replaced by the tuples made by gradients and the index to specify
260
- which inputs to be differentiated or the name of parameters of the training network.
264
+ return_ids(bool, optional): Whether the returned derivation function contains
265
+ `grad_position` or `weights` information. If ``True``,
266
+ all gradient values in the returned derivation function will be replaced
267
+ with: [gradient, grad_position] or [gradient, weights].
261
268
  Default: ``False`` .
262
269
 
263
270
  Returns:
264
- Function, returns the gradient function to calculate forward output and gradient for the input function or cell.
271
+ Function, the derivative function used to compute the gradient of a given function.
265
272
  For example, as for `out1, out2 = fn(*args)` , gradient function will return outputs like
266
- `((out1, out2), gradient)` . When `has_aux` is set to ``True``, only `out1` contributes to the differentiation.
273
+ `((out1, out2), gradient)` . When `has_aux` is set to ``True``,
274
+ only `out1` contributes to the differentiation. If `return_ids` is ``True``,
275
+ all gradient values in the returned derivation function will be replaced
276
+ with: [gradient, grad_position] or [gradient, weights].
267
277
 
268
278
  Raises:
269
- ValueError: If both `grad_position` and `weights` are None.
279
+ ValueError: If both `grad_position` and `weights` are ``None``.
270
280
  TypeError: If type of Args does not belong to required ones.
271
281
 
272
282
  Supported Platforms:
@@ -378,10 +388,10 @@ def get_grad(gradients, identifier):
378
388
  :func:`mindspore.grad`.
379
389
 
380
390
  Returns:
381
- The gradient of the tensor on the position or in the parameter that specified by the `identifier`.
391
+ The Tensor gradient value corresponding to the `identifier`.
382
392
 
383
393
  Raises:
384
- RuntimeError: If gradient is not found.
394
+ RuntimeError: If gradient value corresponding to the `identifier` is not found.
385
395
  TypeError: If type of Args does not belong to required ones.
386
396
 
387
397
  Supported Platforms:
@@ -457,49 +467,42 @@ def jet(fn, primals, series):
457
467
  while the other to 0, which is like the derivative of origin input with respect to itself.
458
468
 
459
469
  Note:
460
- If `primals` is Tensor of int type, it will be converted to Tensor of float type.
470
+ If `primals` is tensor of int type, it will be converted to Tensor of float type.
461
471
 
462
472
  Args:
463
473
  fn (Union[Cell, function]): Function to do TaylorOperation.
464
474
  primals (Union[Tensor, tuple[Tensor]]): The inputs to `fn`.
465
- series (Union[Tensor, tuple[Tensor]]): If tuple, the length and type of series should be the same as inputs.
466
- For each Tensor, the length of first dimension `i` represents the `1` to `i+1`-th order of derivative of
467
- output with respect to the inputs will be figured out.
475
+ series (Union[Tensor, tuple[Tensor]]): The original 1st to nth order derivatives of the input.
476
+ The index `i` of the zeroth dimension of the tensor corresponds to the `i+1` -th order derivative of the
477
+ output with respect to the input.
468
478
 
469
479
  Returns:
470
- Tuple, tuple of out_primals and out_series.
480
+ Tuple(out_primals, out_series)
471
481
 
472
482
  - **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
473
- - **out_series** (Union[Tensor, list[Tensor]]) - The `1` to `i+1`-th order of derivative of output with respect
483
+ - **out_series** (Union[Tensor, list[Tensor]]) - The `1` to `i+1` -th order of derivative of output with respect
474
484
  to the inputs.
475
485
 
476
- Raises:
477
- TypeError: If `primals` is not a tensor or tuple of tensors.
478
- TypeError: If type of `primals` is not the same as type of `series`.
479
-
480
486
  Supported Platforms:
481
487
  ``Ascend`` ``GPU`` ``CPU``
482
488
 
483
489
  Examples:
484
- >>> import numpy as np
485
- >>> import mindspore.nn as nn
486
- >>> import mindspore as ms
487
- >>> from mindspore import ops
488
- >>> from mindspore import Tensor
489
- >>> ms.set_context(mode=ms.GRAPH_MODE)
490
+ >>> import mindspore
491
+ >>> from mindspore import nn
492
+ >>> mindspore.set_context(mode=mindspore.GRAPH_MODE)
490
493
  >>> class Net(nn.Cell):
491
494
  ... def __init__(self):
492
495
  ... super().__init__()
493
- ... self.sin = ops.Sin()
494
- ... self.exp = ops.Exp()
496
+ ... self.sin = mindspore.ops.Sin()
497
+ ... self.exp = mindspore.ops.Exp()
495
498
  ... def construct(self, x):
496
499
  ... out1 = self.sin(x)
497
500
  ... out2 = self.exp(out1)
498
501
  ... return out2
499
- >>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
500
- >>> series = Tensor(np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]).astype(np.float32))
502
+ >>> primals = mindspore.tensor([[1, 2], [3, 4]], mindspore.float32)
503
+ >>> series = mindspore.tensor([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]], mindspore.float32)
501
504
  >>> net = Net()
502
- >>> out_primals, out_series = ops.jet(net, primals, series)
505
+ >>> out_primals, out_series = mindspore.ops.jet(net, primals, series)
503
506
  >>> print(out_primals, out_series)
504
507
  [[2.319777 2.4825778]
505
508
  [1.1515628 0.4691642]] [[[ 1.2533808 -1.0331168 ]
@@ -567,49 +570,41 @@ def derivative(fn, primals, order):
567
570
  input first order derivative is set to 1, while the other to 0.
568
571
 
569
572
  Note:
570
- If `primals` is Tensor of int type, it will be converted to Tensor of float type.
573
+ If `primals` is tensor of int type, it will be converted to tensor of float type.
571
574
 
572
575
  Args:
573
576
  fn (Union[Cell, function]): Function to do TaylorOperation.
574
577
  primals (Union[Tensor, tuple[Tensor]]): The inputs to `fn`.
575
- order (int): For each Tensor, the `order`-th order of derivative of output with respect to the inputs will be
576
- figured out.
578
+ order (int): The order of differentiation.
577
579
 
578
580
  Returns:
579
- Tuple, tuple of out_primals and out_series.
581
+ Tuple(out_primals, out_series)
580
582
 
581
583
  - **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
582
584
  - **out_series** (Union[Tensor, list[Tensor]]) - The `order`-th order of derivative of output with respect
583
585
  to the inputs.
584
586
 
585
- Raises:
586
- TypeError: If `primals` is not a tensor or tuple of tensors.
587
- TypeError: If `order` is not int.
588
- ValueError: If `order` is less than 1.
589
-
590
587
  Supported Platforms:
591
588
  ``Ascend`` ``GPU`` ``CPU``
592
589
 
593
590
  Examples:
594
- >>> import numpy as np
595
- >>> import mindspore as ms
596
- >>> import mindspore.nn as nn
597
- >>> from mindspore import ops
598
- >>> from mindspore import Tensor
599
- >>> ms.set_context(mode=ms.GRAPH_MODE)
591
+ >>> import mindspore
592
+ >>> from mindspore import nn
593
+ >>> mindspore.set_context(mode=mindspore.GRAPH_MODE)
600
594
  >>> class Net(nn.Cell):
601
595
  ... def __init__(self):
602
596
  ... super().__init__()
603
- ... self.sin = ops.Sin()
604
- ... self.exp = ops.Exp()
597
+ ... self.sin = mindspore.ops.Sin()
598
+ ... self.exp = mindspore.ops.Exp()
605
599
  ... def construct(self, x):
606
600
  ... out1 = self.sin(x)
607
601
  ... out2 = self.exp(out1)
608
602
  ... return out2
609
- >>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
603
+ >>>
604
+ >>> primals = mindspore.tensor([[1, 2], [3, 4]], mindspore.float32)
610
605
  >>> order = 3
611
606
  >>> net = Net()
612
- >>> out_primals, out_series = ops.derivative(net, primals, order)
607
+ >>> out_primals, out_series = mindspore.ops.derivative(net, primals, order)
613
608
  >>> print(out_primals, out_series)
614
609
  [[2.319777 2.4825778]
615
610
  [1.1515628 0.4691642]] [[-4.0515366 3.6724353 ]
@@ -677,7 +672,7 @@ def jvp(fn, inputs, v, has_aux=False):
677
672
  - **net_output** (Union[Tensor, tuple[Tensor]]) - The output of `fn(inputs)` . Specially, when `has_aux` is set
678
673
  ``True`` , `net_output` is the first output of `fn(inputs)` .
679
674
  - **jvp** (Union[Tensor, tuple[Tensor]]) - The result of jacobian-vector-product.
680
- - **aux_value** (Union[Tensor, tuple[Tensor]], optional) - When `has_aux` is ``True`` , `aux_value` will be
675
+ - **aux_value** (Union[Tensor, tuple[Tensor]], optional) - Only when `has_aux` is ``True`` , `aux_value` will be
681
676
  returned. It means the second to last outputs of `fn(inputs)` . Specially, `aux_value` does not contribute to
682
677
  gradient.
683
678
 
@@ -841,7 +836,7 @@ def linearize(fn, inputs):
841
836
  """
842
837
  linearize_inner = _LinearizeInner()
843
838
 
844
- @jit(hash_args=fn)
839
+ @jit
845
840
  def _wrap_container(*arg):
846
841
  args = arg[1:-1]
847
842
  vectors = arg[-1]
@@ -882,10 +877,12 @@ def vjp(fn, *inputs, weights=None, has_aux=False):
882
877
  fn (Union[Function, Cell]): The function or net that takes Tensor inputs and returns single Tensor or tuple of
883
878
  Tensors.
884
879
  inputs (Union[Tensor, tuple[Tensor], list[Tensor]]): The inputs to `fn` .
885
- weights (Union[ParameterTuple, Parameter, list[Parameter]]): The parameters of the training network that need to
880
+ weights (Union[ParameterTuple, Parameter, list[Parameter]], optional):
881
+ The parameters of the training network that need to
886
882
  calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
887
883
  Default: ``None`` .
888
- has_aux (bool): If True, only the first output of `fn` contributes the gradient of `fn`, while the other outputs
884
+ has_aux (bool, optional): If True, only the first output of `fn` contributes the gradient of `fn`,
885
+ while the other outputs
889
886
  will be returned directly. It means that `fn` must return more than one output in this case.
890
887
  Default: ``False``.
891
888
 
@@ -1388,6 +1385,41 @@ def stop_gradient(value):
1388
1385
  Supported Platforms:
1389
1386
  ``Ascend`` ``GPU`` ``CPU``
1390
1387
 
1388
+ Examples:
1389
+ >>> import mindspore
1390
+ >>> def f1(x):
1391
+ ... return x ** 2
1392
+ >>> x = 3.0
1393
+ >>> f1(x)
1394
+ 9.0
1395
+ >>> mindspore.ops.grad(f1)(mindspore.tensor(x))
1396
+ Tensor(shape=[], dtype=Float32, value= 6)
1397
+ >>>
1398
+ >>> # The same function with stop_gradient, return a zero gradient because x is effectively treated as a constant.
1399
+ >>> def f2(x):
1400
+ ... return mindspore.ops.stop_gradient(x) ** 2
1401
+ >>> f2(x)
1402
+ 9.0
1403
+ >>> mindspore.ops.grad(f2)(mindspore.tensor(x))
1404
+ Tensor(shape=[], dtype=Float32, value= 0)
1405
+ """
1406
+ return stop_gradient_op(value)
1407
+
1408
+
1409
+ def stop_gradient_(input):
1410
+ """
1411
+ StopGradient inplace
1412
+
1413
+ Args:
1414
+ input (Tensor): input tensor
1415
+
1416
+ Raises:
1417
+ TypeError: If `input` is not a Tensor.
1418
+ RuntimeError: If `input` is a view tensor.
1419
+
1420
+ Supported Platforms:
1421
+ ``Ascend``
1422
+
1391
1423
  Examples:
1392
1424
  >>> from mindspore import ops
1393
1425
  >>> from mindspore import Tensor
@@ -1395,7 +1427,7 @@ def stop_gradient(value):
1395
1427
  >>> def net(x, y):
1396
1428
  ... out1 = ops.MatMul()(x, y)
1397
1429
  ... out2 = ops.MatMul()(x, y)
1398
- ... out2 = ops.stop_gradient(out2)
1430
+ ... ops.stop_gradient_(out2)
1399
1431
  ... return out1, out2
1400
1432
  ...
1401
1433
  >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
@@ -1406,7 +1438,7 @@ def stop_gradient(value):
1406
1438
  [[1.4100001 1.6 6.5999994]
1407
1439
  [1.4100001 1.6 6.5999994]]
1408
1440
  """
1409
- return stop_gradient_(value)
1441
+ inplace_stop_gradient_op(input)
1410
1442
 
1411
1443
 
1412
1444
  __all__ = [
@@ -1420,6 +1452,7 @@ __all__ = [
1420
1452
  'vjp',
1421
1453
  'linearize',
1422
1454
  'stop_gradient',
1455
+ 'stop_gradient_',
1423
1456
  'get_grad'
1424
1457
  ]
1425
1458
  __all__.sort()
@@ -19,7 +19,7 @@ from mindspore.ops import operations as P
19
19
  from mindspore.ops.operations import image_ops as IMG
20
20
  import mindspore.common.dtype as mstype
21
21
  from mindspore.common.tensor import Tensor
22
- from mindspore._c_expression import Tensor as Tensor_
22
+ from mindspore._c_expression import TensorPy as Tensor_
23
23
  from .._primitive_cache import _get_cache_prim
24
24
 
25
25
  check_valid_ = P.CheckValid()
@@ -252,7 +252,7 @@ def crop_and_resize(image, boxes, box_indices, crop_size, method="bilinear", ext
252
252
  >>> crop_size = (24, 24)
253
253
  >>> output = ops.crop_and_resize(Tensor(image), Tensor(boxes), Tensor(box_indices), crop_size)
254
254
  >>> print(output.shape)
255
- (5, 24, 24, 3)
255
+ (5, 24, 24, 3)
256
256
  """
257
257
  _crop_and_resize_check(image, boxes, box_indices, crop_size)
258
258
  image_shape = image.shape
@@ -23,7 +23,7 @@ from mindspore.ops import operations as P
23
23
  from mindspore.ops import functional as F
24
24
  from mindspore.ops.operations import _inner_ops as inner
25
25
  from mindspore.ops.function.math_func import _check_input_dtype, _check_attr_dtype
26
- from mindspore._c_expression import Tensor as Tensor_
26
+ from mindspore._c_expression import TensorPy as Tensor_
27
27
  from mindspore.ops.auto_generate import geqrf
28
28
 
29
29
  from ..operations import linalg_ops
@@ -37,7 +37,7 @@ slice_ = P.Slice()
37
37
 
38
38
  def cond(A, p=None):
39
39
  r"""
40
- Returns the matrix norm or vector norm of a given tensor.
40
+ Return the matrix norm or vector norm of a given tensor.
41
41
 
42
42
  `p` is the calculation mode of norm. The following norm modes are supported.
43
43
 
@@ -61,28 +61,22 @@ def cond(A, p=None):
61
61
  Currently, complex numbers are not supported.
62
62
 
63
63
  Args:
64
- A (Tensor): Tensor of shape :math:`(*, n)` or :math:`(*, m, n)`
65
- where :math:`*` is zero or more batch dimensions.
66
- p (Union[int, float, inf, -inf, 'fro', 'nuc'], optional): norm's mode. Refer to the table above for
67
- behavior. Default: ``None``.
64
+ A (Tensor): The input tensor which is zero or more batch dimensions.
65
+ p (Union[int, float, inf, -inf, 'fro', 'nuc'], optional): Norm's mode. Refer to the table above for
66
+ behavior. Default ``None``.
68
67
 
69
68
  Returns:
70
- Tensor, the result of norm calculation on the specified dimension, `dim`, has the same dtype as `A`.
71
-
72
- Raises:
73
- TypeError: If `A` is a vector and `p` is a str.
74
- ValueError: If `A` is a matrices and `p` is not in valid mode.
75
- ValueError: If `A` is a matrix and `p` is an integer that is not in [1, -1, 2, -2].
69
+ Tensor
76
70
 
77
71
  Supported Platforms:
78
72
  ``GPU`` ``CPU``
79
73
 
80
74
  Examples:
81
- >>> import mindspore as ms
82
- >>> x = ms.Tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
83
- >>> print(ms.ops.cond(x))
75
+ >>> import mindspore
76
+ >>> x = mindspore.tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
77
+ >>> print(mindspore.ops.cond(x))
84
78
  1.4142
85
- >>> print(ms.ops.cond(x, 'fro'))
79
+ >>> print(mindspore.ops.cond(x, 'fro'))
86
80
  3.1622777
87
81
  """
88
82
  matrix_inverse = _get_cache_prim(P.MatrixInverse)(adjoint=False)
@@ -119,7 +113,7 @@ def eig(A):
119
113
  Returns:
120
114
  - **eigen_values** (Tensor) - Shape :math:`(*, N)`. eigenvalues of
121
115
  the corresponding matrix. The eigenvalues may not have an order.
122
- - **eigen_vectors** (Tensor) - Shape :math:`(*, N, N)`,columns of eigen vectors represent
116
+ - **eigen_vectors** (Tensor) - Shape :math:`(*, N, N)`, columns of eigen vectors represent
123
117
  normalized (unit length) eigenvectors of corresponding eigenvalues.
124
118
 
125
119
  Raises:
@@ -147,35 +141,24 @@ def eig(A):
147
141
 
148
142
  def eigvals(A):
149
143
  """
150
- Computes the eigenvalues of a square matrix(batch square matrices).
144
+ Compute the eigenvalues of a square matrix.
151
145
 
152
146
  .. warning::
153
147
  This is an experimental API that is subject to change or deletion.
154
148
 
155
149
  Args:
156
- A (Tensor): Square matrices of shape :math:`(*, N, N)`,
157
- with float32, float64, complex64 or complex128 data type.
150
+ A (Tensor): Square matrices with shape :math:`(*, N, N)` .
158
151
 
159
152
  Returns:
160
- Tensor, with shape :math:`(*, N)`. Returns the eigenvalues of
161
- the corresponding matrix, which may not have an order.
162
-
163
- Raises:
164
- TypeError: If dtype of `A` is not one of: float64, float32, complex64 or complex128.
165
- TypeError: If `A` is not a Tensor.
166
- ValueError: If `A` is not a square(batch squares).
153
+ Tensor
167
154
 
168
155
  Supported Platforms:
169
156
  ``Ascend`` ``CPU``
170
157
 
171
158
  Examples:
172
159
  >>> import mindspore
173
- >>> from mindspore import Tensor, ops
174
- >>> import numpy as np
175
- >>> input_x = Tensor(np.array([[1.0, 0.0], [0.0, 2.0]]), mindspore.float32)
176
- >>> u = ops.eigvals(input_x)
177
- >>> print(u)
178
- [1.+0.j 2.+0.j]
160
+ >>> mindspore.ops.eigvals(mindspore.tensor([[1.0, 0.0], [0.0, 2.0]]))
161
+ Tensor(shape=[2], dtype=Complex64, value= [1+0j, 2+0j])
179
162
  """
180
163
  u, _ = _get_cache_prim(P.Eig)(compute_v=False)(A)
181
164
  return u
@@ -192,45 +175,39 @@ def svd(input, full_matrices=False, compute_uv=True):
192
175
  A=U*diag(S)*V^{T}
193
176
 
194
177
  Args:
195
- input (Tensor): Tensor of the matrices to be decomposed. The shape should be :math:`(*, M, N)`,
196
- the supported dtype are float32 and float64.
197
- full_matrices (bool, optional): If true, compute full-sized :math:`U` and :math:`V`. If false, compute
178
+ input (Tensor): The input tensor, shape is :math:`(*, M, N)`.
179
+ full_matrices (bool, optional): If ``True`` , compute full-sized :math:`U` and :math:`V`. If ``False``, compute
198
180
  only the leading P singular vectors, where P is the minimum of M and N.
199
- Default: ``False`` .
200
- compute_uv (bool, optional): If true, compute the left and right singular vectors.
201
- If false, compute only the singular values. Default: ``True`` .
181
+ Default ``False`` .
182
+ compute_uv (bool, optional): If ``True`` , compute the left and right singular vectors.
183
+ If ``False``, compute only the singular values. Default ``True`` .
202
184
 
203
185
  Returns:
204
- - **s** (Tensor) - Singular values. The shape is :math:`(*, P)`.
205
- - **u** (Tensor) - Left singular vectors. If `compute_uv` is False, u will not be returned.
206
- The shape is :math:`(*, M, P)`. If `full_matrices` is True, the shape will be :math:`(*, M, M)`.
207
- - **v** (Tensor) - Right singular vectors. If `compute_uv` is False, v will not be returned.
208
- The shape is :math:`(*, N, P)`. If `full_matrices` is True, the shape will be :math:`(*, N, N)`.
186
+ If compute_uv is ``True`` , a tuple( `s` , `u` , `v` ) of tensors will be returned. Otherwise, only
187
+ a single tensor -> `s` will be returned.
209
188
 
210
- Raises:
211
- TypeError: If `full_matrices` or `compute_uv` is not the type of bool.
212
- TypeError: If the rank of input less than 2.
213
- TypeError: If the type of input is not one of the following dtype: float32, float64.
189
+ - `s` is the singular value tensor. The shape is :math:`(*, P)`.
190
+ - `u` is the left singular tensor. If `compute_uv` is ``False`` , `u` will not be returned.
191
+ The shape is :math:`(*, M, P)`. If `full_matrices` is ``True`` , the shape will be :math:`(*, M, M)`.
192
+ - `v` is the right singular tensor. If `compute_uv` is ``False`` , `v` will not be returned.
193
+ The shape is :math:`(*, N, P)`. If `full_matrices` is ``True`` , the shape will be :math:`(*, N, N)`.
214
194
 
215
195
  Supported Platforms:
216
- ``GPU`` ``CPU``
196
+ ``Ascend`` ``GPU`` ``CPU``
217
197
 
218
198
  Examples:
219
- >>> import numpy as np
220
- >>> from mindspore import Tensor, set_context
221
- >>> from mindspore import ops
222
- >>> set_context(device_target="CPU")
223
- >>> input = Tensor(np.array([[1, 2], [-4, -5], [2, 1]]).astype(np.float32))
224
- >>> s, u, v = ops.svd(input, full_matrices=True, compute_uv=True)
199
+ >>> import mindspore
200
+ >>> input = mindspore.tensor([[1, 2], [-4, -5], [2, 1]], mindspore.float32)
201
+ >>> s, u, v = mindspore.ops.svd(input, full_matrices=True, compute_uv=True)
225
202
  >>> print(s)
226
203
  [7.0652843 1.040081 ]
227
204
  >>> print(u)
228
- [[ 0.30821905 -0.48819482 0.81649697]
229
- [-0.90613353 0.11070572 0.40824813]
230
- [ 0.2896955 0.8656849 0.4082479 ]]
205
+ [[ 0.30821905 -0.48819482 0.81649697]
206
+ [-0.90613353 0.11070572 0.40824813]
207
+ [ 0.2896955 0.8656849 0.4082479 ]]
231
208
  >>> print(v)
232
- [[ 0.63863593 0.769509 ]
233
- [ 0.769509 -0.63863593]]
209
+ [[ 0.63863593 0.769509 ]
210
+ [ 0.769509 -0.63863593]]
234
211
  """
235
212
  svd_ = _get_cache_prim(linalg_ops.Svd)(full_matrices=full_matrices, compute_uv=compute_uv)
236
213
 
@@ -244,6 +221,7 @@ def svd(input, full_matrices=False, compute_uv=True):
244
221
  def pinv(x, *, atol=None, rtol=None, hermitian=False):
245
222
  r"""
246
223
  Computes the (Moore-Penrose) pseudo-inverse of a matrix.
224
+
247
225
  This function is computed using SVD. If :math:`x=U*S*V^{T}` , then the pseudo-inverse of x is:
248
226
  :math:`x^{+}=V*S^{+}*U^{T}` , :math:`S^{+}` is the reciprocal of each non-zero element on
249
227
  the diagonal of S, and zero remains in place.
@@ -271,32 +249,24 @@ def pinv(x, *, atol=None, rtol=None, hermitian=False):
271
249
  see the warnings in svd() and eigh().
272
250
 
273
251
  Args:
274
- x (Tensor): A matrix to be calculated. Only `float32`, `float64` are supported Tensor dtypes.
275
- shape is :math:`(*, M, N)`, * is zero or more batch dimensions.
276
-
277
- - When `hermitian` is ``True``, batch dimensions are not supported temporarily.
252
+ x (Tensor): The input tensor whose shape is :math:`(*, M, N)`, * is zero or more batch dimensions.
253
+ When `hermitian` is ``True``, batch dimensions are not supported temporarily.
278
254
 
279
255
  Keyword args:
280
- atol (float, Tensor): absolute tolerance value. Default: ``None`` .
281
- rtol (float, Tensor): relative tolerance value. Default: ``None`` .
282
- hermitian (bool): An optional bool. x is assumed to be symmetric if real. Default: ``False`` .
256
+ atol (float, Tensor): The absolute tolerance value. Default ``None`` .
257
+ rtol (float, Tensor): The relative tolerance value. Default ``None`` .
258
+ hermitian (bool): Whether `x` is assumed to be symmetric if real. Default ``False`` .
283
259
 
284
260
  Outputs:
285
- - **output** (Tensor) - same type as input. Shape is :math:`(*, N, M)`, * is zero or more batch dimensions.
286
-
287
- Raises:
288
- TypeError: If `hermitian` is not a bool.
289
- TypeError: If `x` is not a Tensor.
290
- ValueError: If the dimension of `x` is less than 2.
261
+ A tensor whose shape is :math:`(*, N, M)`, * is zero or more batch dimensions.
291
262
 
292
263
  Supported Platforms:
293
264
  ``CPU``
294
265
 
295
266
  Examples:
296
267
  >>> import mindspore
297
- >>> from mindspore import Tensor, ops
298
- >>> x = Tensor([[4., 0.], [0., 5.]], mindspore.float32)
299
- >>> output = ops.pinv(x)
268
+ >>> x = mindspore.tensor([[4., 0.], [0., 5.]], mindspore.float32)
269
+ >>> output = mindspore.ops.pinv(x)
300
270
  >>> print(output)
301
271
  [[0.25 0. ]
302
272
  [0. 0.2 ]]
@@ -327,8 +297,7 @@ def pinv(x, *, atol=None, rtol=None, hermitian=False):
327
297
 
328
298
  if not hermitian:
329
299
  s, u, v = linalg_ops.Svd()(x)
330
- max_singular_val = _narrow(s, -1, 0, 1)
331
- threshold = ops.Maximum()(atol.expand_dims(-1), rtol.expand_dims(-1) * max_singular_val)
300
+ threshold = ops.Maximum()(atol.expand_dims(-1), rtol.expand_dims(-1) * _narrow(s, -1, 0, 1))
332
301
  condition = s > threshold
333
302
  reciprocal_s_before = ops.Reciprocal()(s).broadcast_to(condition.shape)
334
303
  zero = F.zeros(condition.shape, s.dtype)