mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

This version of mindspore has been flagged as a potentially problematic release.

Files changed (285)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/mindir_util.py +2 -2
  26. mindspore/common/parameter.py +46 -13
  27. mindspore/common/recompute.py +39 -9
  28. mindspore/common/sparse_tensor.py +7 -3
  29. mindspore/common/tensor.py +209 -29
  30. mindspore/communication/__init__.py +1 -1
  31. mindspore/communication/_comm_helper.py +38 -3
  32. mindspore/communication/comm_func.py +310 -55
  33. mindspore/communication/management.py +14 -14
  34. mindspore/context.py +123 -22
  35. mindspore/dataset/__init__.py +1 -1
  36. mindspore/dataset/audio/__init__.py +1 -1
  37. mindspore/dataset/core/config.py +7 -0
  38. mindspore/dataset/core/validator_helpers.py +7 -0
  39. mindspore/dataset/engine/cache_client.py +1 -1
  40. mindspore/dataset/engine/datasets.py +72 -44
  41. mindspore/dataset/engine/datasets_audio.py +7 -7
  42. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  43. mindspore/dataset/engine/datasets_text.py +20 -20
  44. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  45. mindspore/dataset/engine/datasets_vision.py +33 -33
  46. mindspore/dataset/engine/iterators.py +29 -0
  47. mindspore/dataset/engine/obs/util.py +7 -0
  48. mindspore/dataset/engine/queue.py +114 -60
  49. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  50. mindspore/dataset/engine/validators.py +34 -14
  51. mindspore/dataset/text/__init__.py +1 -4
  52. mindspore/dataset/transforms/__init__.py +0 -3
  53. mindspore/dataset/utils/line_reader.py +2 -0
  54. mindspore/dataset/vision/__init__.py +1 -4
  55. mindspore/dataset/vision/utils.py +1 -1
  56. mindspore/dataset/vision/validators.py +2 -1
  57. mindspore/dnnl.dll +0 -0
  58. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  59. mindspore/experimental/es/embedding_service.py +883 -0
  60. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  61. mindspore/experimental/llm_boost/__init__.py +21 -0
  62. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  63. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  64. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  65. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  66. mindspore/experimental/llm_boost/register.py +129 -0
  67. mindspore/experimental/llm_boost/utils.py +31 -0
  68. mindspore/experimental/optim/adamw.py +85 -0
  69. mindspore/experimental/optim/optimizer.py +3 -0
  70. mindspore/hal/__init__.py +3 -3
  71. mindspore/hal/contiguous_tensors_handle.py +175 -0
  72. mindspore/hal/stream.py +18 -0
  73. mindspore/include/api/model_group.h +13 -1
  74. mindspore/include/api/types.h +10 -10
  75. mindspore/include/dataset/config.h +2 -2
  76. mindspore/include/dataset/constants.h +2 -2
  77. mindspore/include/dataset/execute.h +2 -2
  78. mindspore/include/dataset/vision.h +4 -0
  79. mindspore/jpeg62.dll +0 -0
  80. mindspore/log.py +1 -1
  81. mindspore/mindrecord/filewriter.py +68 -51
  82. mindspore/mindspore_backend.dll +0 -0
  83. mindspore/mindspore_common.dll +0 -0
  84. mindspore/mindspore_core.dll +0 -0
  85. mindspore/mindspore_glog.dll +0 -0
  86. mindspore/mindspore_np_dtype.dll +0 -0
  87. mindspore/mindspore_ops.dll +0 -0
  88. mindspore/mint/__init__.py +495 -46
  89. mindspore/mint/distributed/__init__.py +31 -0
  90. mindspore/mint/distributed/distributed.py +254 -0
  91. mindspore/mint/nn/__init__.py +266 -21
  92. mindspore/mint/nn/functional.py +125 -19
  93. mindspore/mint/nn/layer/__init__.py +39 -0
  94. mindspore/mint/nn/layer/activation.py +133 -0
  95. mindspore/mint/nn/layer/normalization.py +477 -0
  96. mindspore/mint/nn/layer/pooling.py +110 -0
  97. mindspore/mint/optim/adamw.py +28 -7
  98. mindspore/mint/special/__init__.py +63 -0
  99. mindspore/multiprocessing/__init__.py +2 -1
  100. mindspore/nn/__init__.py +0 -1
  101. mindspore/nn/cell.py +275 -93
  102. mindspore/nn/layer/activation.py +211 -44
  103. mindspore/nn/layer/basic.py +113 -3
  104. mindspore/nn/layer/embedding.py +120 -2
  105. mindspore/nn/layer/normalization.py +101 -5
  106. mindspore/nn/layer/padding.py +34 -48
  107. mindspore/nn/layer/pooling.py +161 -7
  108. mindspore/nn/layer/transformer.py +3 -3
  109. mindspore/nn/loss/__init__.py +2 -2
  110. mindspore/nn/loss/loss.py +84 -6
  111. mindspore/nn/optim/__init__.py +2 -1
  112. mindspore/nn/optim/adadelta.py +1 -1
  113. mindspore/nn/optim/adam.py +1 -1
  114. mindspore/nn/optim/lamb.py +1 -1
  115. mindspore/nn/optim/tft_wrapper.py +127 -0
  116. mindspore/nn/wrap/cell_wrapper.py +12 -23
  117. mindspore/nn/wrap/grad_reducer.py +5 -5
  118. mindspore/nn/wrap/loss_scale.py +17 -3
  119. mindspore/numpy/__init__.py +1 -1
  120. mindspore/numpy/array_creations.py +65 -68
  121. mindspore/numpy/array_ops.py +64 -60
  122. mindspore/numpy/fft.py +610 -75
  123. mindspore/numpy/logic_ops.py +11 -10
  124. mindspore/numpy/math_ops.py +85 -84
  125. mindspore/numpy/utils_const.py +4 -4
  126. mindspore/opencv_core452.dll +0 -0
  127. mindspore/opencv_imgcodecs452.dll +0 -0
  128. mindspore/opencv_imgproc452.dll +0 -0
  129. mindspore/ops/__init__.py +6 -4
  130. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  131. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  132. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  133. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  134. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  135. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  136. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  137. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  138. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  139. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  140. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  141. mindspore/ops/composite/base.py +85 -48
  142. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  143. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  144. mindspore/ops/function/__init__.py +22 -0
  145. mindspore/ops/function/array_func.py +490 -153
  146. mindspore/ops/function/debug_func.py +113 -1
  147. mindspore/ops/function/fft_func.py +15 -2
  148. mindspore/ops/function/grad/grad_func.py +3 -2
  149. mindspore/ops/function/math_func.py +558 -207
  150. mindspore/ops/function/nn_func.py +817 -383
  151. mindspore/ops/function/other_func.py +3 -2
  152. mindspore/ops/function/random_func.py +184 -8
  153. mindspore/ops/function/reshard_func.py +13 -11
  154. mindspore/ops/function/sparse_unary_func.py +1 -1
  155. mindspore/ops/function/vmap_func.py +3 -2
  156. mindspore/ops/functional.py +24 -14
  157. mindspore/ops/op_info_register.py +3 -3
  158. mindspore/ops/operations/__init__.py +6 -1
  159. mindspore/ops/operations/_grad_ops.py +2 -76
  160. mindspore/ops/operations/_infer_ops.py +1 -1
  161. mindspore/ops/operations/_inner_ops.py +71 -94
  162. mindspore/ops/operations/array_ops.py +12 -146
  163. mindspore/ops/operations/comm_ops.py +42 -53
  164. mindspore/ops/operations/custom_ops.py +83 -19
  165. mindspore/ops/operations/debug_ops.py +42 -10
  166. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  167. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  168. mindspore/ops/operations/math_ops.py +12 -223
  169. mindspore/ops/operations/nn_ops.py +20 -114
  170. mindspore/ops/operations/other_ops.py +7 -4
  171. mindspore/ops/operations/random_ops.py +46 -1
  172. mindspore/ops/primitive.py +18 -6
  173. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  174. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  175. mindspore/ops_generate/gen_constants.py +36 -0
  176. mindspore/ops_generate/gen_ops.py +67 -52
  177. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  178. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  179. mindspore/ops_generate/op_proto.py +10 -3
  180. mindspore/ops_generate/pyboost_utils.py +14 -1
  181. mindspore/ops_generate/template.py +43 -21
  182. mindspore/parallel/__init__.py +3 -1
  183. mindspore/parallel/_auto_parallel_context.py +28 -8
  184. mindspore/parallel/_cell_wrapper.py +83 -0
  185. mindspore/parallel/_parallel_serialization.py +47 -19
  186. mindspore/parallel/_tensor.py +81 -11
  187. mindspore/parallel/_utils.py +13 -1
  188. mindspore/parallel/algo_parameter_config.py +5 -5
  189. mindspore/parallel/checkpoint_transform.py +46 -39
  190. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  191. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  192. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  193. mindspore/parallel/parameter_broadcast.py +3 -4
  194. mindspore/parallel/shard.py +162 -31
  195. mindspore/parallel/transform_safetensors.py +993 -0
  196. mindspore/profiler/__init__.py +2 -1
  197. mindspore/profiler/common/constant.py +29 -0
  198. mindspore/profiler/common/registry.py +47 -0
  199. mindspore/profiler/common/util.py +28 -0
  200. mindspore/profiler/dynamic_profiler.py +694 -0
  201. mindspore/profiler/envprofiling.py +17 -19
  202. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  203. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  204. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  205. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  206. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  207. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  208. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  209. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  210. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  211. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  212. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  213. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  214. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  215. mindspore/profiler/parser/framework_parser.py +1 -391
  216. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  217. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  218. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  219. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  220. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  221. mindspore/profiler/parser/profiler_info.py +78 -6
  222. mindspore/profiler/profiler.py +153 -0
  223. mindspore/profiler/profiling.py +280 -412
  224. mindspore/rewrite/__init__.py +1 -2
  225. mindspore/rewrite/common/namespace.py +4 -4
  226. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  227. mindspore/run_check/_check_version.py +36 -103
  228. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  229. mindspore/swresample-4.dll +0 -0
  230. mindspore/swscale-6.dll +0 -0
  231. mindspore/tinyxml2.dll +0 -0
  232. mindspore/train/__init__.py +4 -3
  233. mindspore/train/_utils.py +28 -2
  234. mindspore/train/amp.py +171 -53
  235. mindspore/train/callback/__init__.py +2 -2
  236. mindspore/train/callback/_callback.py +4 -4
  237. mindspore/train/callback/_checkpoint.py +85 -22
  238. mindspore/train/callback/_cluster_monitor.py +1 -1
  239. mindspore/train/callback/_flops_collector.py +1 -0
  240. mindspore/train/callback/_loss_monitor.py +3 -3
  241. mindspore/train/callback/_on_request_exit.py +134 -31
  242. mindspore/train/callback/_summary_collector.py +5 -5
  243. mindspore/train/callback/_tft_register.py +352 -0
  244. mindspore/train/dataset_helper.py +7 -3
  245. mindspore/train/metrics/metric.py +3 -3
  246. mindspore/train/metrics/roc.py +4 -4
  247. mindspore/train/mind_ir_pb2.py +44 -39
  248. mindspore/train/model.py +134 -58
  249. mindspore/train/serialization.py +336 -112
  250. mindspore/turbojpeg.dll +0 -0
  251. mindspore/utils/__init__.py +21 -0
  252. mindspore/utils/utils.py +60 -0
  253. mindspore/version.py +1 -1
  254. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  255. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
  256. mindspore/include/c_api/ms/abstract.h +0 -67
  257. mindspore/include/c_api/ms/attribute.h +0 -197
  258. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  259. mindspore/include/c_api/ms/base/macros.h +0 -32
  260. mindspore/include/c_api/ms/base/status.h +0 -33
  261. mindspore/include/c_api/ms/base/types.h +0 -283
  262. mindspore/include/c_api/ms/context.h +0 -102
  263. mindspore/include/c_api/ms/graph.h +0 -160
  264. mindspore/include/c_api/ms/node.h +0 -606
  265. mindspore/include/c_api/ms/tensor.h +0 -161
  266. mindspore/include/c_api/ms/value.h +0 -84
  267. mindspore/mindspore_shared_lib.dll +0 -0
  268. mindspore/nn/extend/basic.py +0 -140
  269. mindspore/nn/extend/embedding.py +0 -143
  270. mindspore/nn/extend/layer/normalization.py +0 -109
  271. mindspore/nn/extend/pooling.py +0 -117
  272. mindspore/nn/layer/embedding_service.py +0 -531
  273. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  274. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  275. mindspore/ops/extend/__init__.py +0 -53
  276. mindspore/ops/extend/array_func.py +0 -218
  277. mindspore/ops/extend/math_func.py +0 -76
  278. mindspore/ops/extend/nn_func.py +0 -308
  279. mindspore/ops/silent_check.py +0 -162
  280. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  281. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  282. mindspore/train/callback/_mindio_ttp.py +0 -443
  283. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  284. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  285. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
--- a/mindspore/ops/auto_generate/pyboost_inner_prim.py
+++ b/mindspore/ops/auto_generate/pyboost_inner_prim.py
@@ -26,6 +26,8 @@ from mindspore._c_expression import BroadcastToPrim_
 from mindspore._c_expression import ConcatPrim_
 from mindspore._c_expression import ConvolutionGradPrim_
 from mindspore._c_expression import ConvolutionPrim_
+from mindspore._c_expression import CrossPrim_
+from mindspore._c_expression import CummaxPrim_
 from mindspore._c_expression import EluExtPrim_
 from mindspore._c_expression import FFNExtPrim_
 from mindspore._c_expression import FlashAttentionScoreGradPrim_
@@ -34,20 +36,30 @@ from mindspore._c_expression import GridSampler2DGradPrim_
 from mindspore._c_expression import GridSampler2DPrim_
 from mindspore._c_expression import GridSampler3DGradPrim_
 from mindspore._c_expression import GridSampler3DPrim_
+from mindspore._c_expression import HShrinkGradPrim_
+from mindspore._c_expression import HShrinkPrim_
+from mindspore._c_expression import IncreFlashAttentionPrim_
 from mindspore._c_expression import IsClosePrim_
+from mindspore._c_expression import LogSoftmaxGradPrim_
+from mindspore._c_expression import LogSoftmaxPrim_
 from mindspore._c_expression import MatMulPrim_
 from mindspore._c_expression import MaxPoolGradWithIndicesPrim_
 from mindspore._c_expression import MaxPoolGradWithMaskPrim_
 from mindspore._c_expression import MaxPoolWithIndicesPrim_
 from mindspore._c_expression import MaxPoolWithMaskPrim_
+from mindspore._c_expression import NanToNumPrim_
 from mindspore._c_expression import OneHotExtPrim_
 from mindspore._c_expression import ReduceAllPrim_
 from mindspore._c_expression import ReduceAnyPrim_
 from mindspore._c_expression import ReverseV2Prim_
 from mindspore._c_expression import RmsNormPrim_
+from mindspore._c_expression import RollPrim_
 from mindspore._c_expression import SearchSortedPrim_
 from mindspore._c_expression import SoftmaxPrim_
+from mindspore._c_expression import SoftShrinkGradPrim_
+from mindspore._c_expression import SoftShrinkPrim_
 from mindspore._c_expression import StackExtPrim_
+from mindspore._c_expression import TrilExtPrim_
 from mindspore._c_expression import TriuPrim_
 from mindspore._c_expression import UpsampleTrilinear3DGradPrim_
 from mindspore._c_expression import UpsampleTrilinear3DPrim_
@@ -94,8 +106,8 @@ batch_norm_grad_ext_impl = _PyboostBatchNormGradExtPrim()

 class _PyboostBinaryCrossEntropyGradPrim(BinaryCrossEntropyGradPrim_):
     def __call__(self, input, target, grad_output, weight, reduction):
-        converted_reduction = str_to_enum(reduction)
-        return _convert_stub(super().__call__(input, target, grad_output, weight, reduction))
+        converted_reduction = str_to_enum('binary_cross_entropy_grad', 'reduction', reduction)
+        return _convert_stub(super().__call__(input, target, grad_output, weight, converted_reduction))


 binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()
@@ -103,8 +115,8 @@ binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()

 class _PyboostBinaryCrossEntropyPrim(BinaryCrossEntropyPrim_):
     def __call__(self, input, target, weight, reduction):
-        converted_reduction = str_to_enum(reduction)
-        return _convert_stub(super().__call__(input, target, weight, reduction))
+        converted_reduction = str_to_enum('binary_cross_entropy', 'reduction', reduction)
+        return _convert_stub(super().__call__(input, target, weight, converted_reduction))


 binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()
@@ -112,8 +124,8 @@ binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()

 class _PyboostBCEWithLogitsLossPrim(BCEWithLogitsLossPrim_):
     def __call__(self, input, target, weight, posWeight, reduction):
-        converted_reduction = str_to_enum(reduction)
-        return _convert_stub(super().__call__(input, target, weight, posWeight, reduction))
+        converted_reduction = str_to_enum('binary_cross_entropy_with_logits', 'reduction', reduction)
+        return _convert_stub(super().__call__(input, target, weight, posWeight, converted_reduction))


 binary_cross_entropy_with_logits_impl = _PyboostBCEWithLogitsLossPrim()
@@ -139,11 +151,11 @@ concat_impl = _PyboostConcatPrim()

 class _PyboostConvolutionGradPrim(ConvolutionGradPrim_):
     def __call__(self, dout, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, output_mask):
-        converted_stride = to_strides(stride)
-        converted_padding = to_2d_paddings(padding)
-        converted_dilation = to_dilations(dilation)
-        converted_output_padding = to_output_padding(output_padding)
-        return _convert_stub(super().__call__(dout, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, output_mask))
+        converted_stride = to_strides('convolution_grad', 'stride', stride)
+        converted_padding = to_2d_paddings('convolution_grad', 'padding', padding)
+        converted_dilation = to_dilations('convolution_grad', 'dilation', dilation)
+        converted_output_padding = to_output_padding('convolution_grad', 'output_padding', output_padding)
+        return _convert_stub(super().__call__(dout, input, weight, bias, converted_stride, converted_padding, converted_dilation, transposed, converted_output_padding, groups, output_mask))


 convolution_grad_impl = _PyboostConvolutionGradPrim()
@@ -151,16 +163,34 @@ convolution_grad_impl = _PyboostConvolutionGradPrim()

 class _PyboostConvolutionPrim(ConvolutionPrim_):
     def __call__(self, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
-        converted_stride = to_strides(stride)
-        converted_padding = to_2d_paddings(padding)
-        converted_dilation = to_dilations(dilation)
-        converted_output_padding = to_output_padding(output_padding)
-        return _convert_stub(super().__call__(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups))
+        converted_stride = to_strides('convolution', 'stride', stride)
+        converted_padding = to_2d_paddings('convolution', 'padding', padding)
+        converted_dilation = to_dilations('convolution', 'dilation', dilation)
+        converted_output_padding = to_output_padding('convolution', 'output_padding', output_padding)
+        return _convert_stub(super().__call__(input, weight, bias, converted_stride, converted_padding, converted_dilation, transposed, converted_output_padding, groups))


 convolution_impl = _PyboostConvolutionPrim()


+class _PyboostCrossPrim(CrossPrim_):
+    def __call__(self, input, other, dim):
+
+        return _convert_stub(super().__call__(input, other, dim))
+
+
+cross_impl = _PyboostCrossPrim()
+
+
+class _PyboostCummaxPrim(CummaxPrim_):
+    def __call__(self, input, axis):
+
+        return _convert_stub(super().__call__(input, axis))
+
+
+cummax_impl = _PyboostCummaxPrim()
+
+
 class _PyboostEluExtPrim(EluExtPrim_):
     def __call__(self, input, alpha):

@@ -172,8 +202,8 @@ elu_ext_impl = _PyboostEluExtPrim()

 class _PyboostFFNExtPrim(FFNExtPrim_):
     def __call__(self, x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, activation, inner_precise):
-        converted_activation = str_to_enum(activation)
-        return _convert_stub(super().__call__(x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, activation, inner_precise))
+        converted_activation = str_to_enum('ffn_ext', 'activation', activation)
+        return _convert_stub(super().__call__(x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, converted_activation, inner_precise))


 ffn_ext_impl = _PyboostFFNExtPrim()
@@ -181,8 +211,8 @@ ffn_ext_impl = _PyboostFFNExtPrim()

 class _PyboostFlashAttentionScoreGradPrim(FlashAttentionScoreGradPrim_):
     def __call__(self, query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode):
-        converted_input_layout = str_to_enum(input_layout)
-        return _convert_stub(super().__call__(query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode))
+        converted_input_layout = str_to_enum('flash_attention_score_grad', 'input_layout', input_layout)
+        return _convert_stub(super().__call__(query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode))


 flash_attention_score_grad_impl = _PyboostFlashAttentionScoreGradPrim()
@@ -190,8 +220,8 @@ flash_attention_score_grad_impl = _PyboostFlashAttentionScoreGradPrim()

 class _PyboostFlashAttentionScorePrim(FlashAttentionScorePrim_):
     def __call__(self, query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode):
-        converted_input_layout = str_to_enum(input_layout)
-        return _convert_stub(super().__call__(query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode))
+        converted_input_layout = str_to_enum('flash_attention_score', 'input_layout', input_layout)
+        return _convert_stub(super().__call__(query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode))


 flash_attention_score_impl = _PyboostFlashAttentionScorePrim()
@@ -199,9 +229,9 @@ flash_attention_score_impl = _PyboostFlashAttentionScorePrim()

 class _PyboostGridSampler2DGradPrim(GridSampler2DGradPrim_):
     def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners):
-        converted_interpolation_mode = str_to_enum(interpolation_mode)
-        converted_padding_mode = str_to_enum(padding_mode)
-        return _convert_stub(super().__call__(grad, input_x, grid, interpolation_mode, padding_mode, align_corners))
+        converted_interpolation_mode = str_to_enum('grid_sampler_2d_grad', 'interpolation_mode', interpolation_mode)
+        converted_padding_mode = str_to_enum('grid_sampler_2d_grad', 'padding_mode', padding_mode)
+        return _convert_stub(super().__call__(grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners))


 grid_sampler_2d_grad_impl = _PyboostGridSampler2DGradPrim()
@@ -209,9 +239,9 @@ grid_sampler_2d_grad_impl = _PyboostGridSampler2DGradPrim()

 class _PyboostGridSampler2DPrim(GridSampler2DPrim_):
     def __call__(self, input_x, grid, interpolation_mode, padding_mode, align_corners):
-        converted_interpolation_mode = str_to_enum(interpolation_mode)
-        converted_padding_mode = str_to_enum(padding_mode)
-        return _convert_stub(super().__call__(input_x, grid, interpolation_mode, padding_mode, align_corners))
+        converted_interpolation_mode = str_to_enum('grid_sampler_2d', 'interpolation_mode', interpolation_mode)
+        converted_padding_mode = str_to_enum('grid_sampler_2d', 'padding_mode', padding_mode)
+        return _convert_stub(super().__call__(input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners))


 grid_sampler_2d_impl = _PyboostGridSampler2DPrim()
@@ -219,9 +249,9 @@ grid_sampler_2d_impl = _PyboostGridSampler2DPrim()

 class _PyboostGridSampler3DGradPrim(GridSampler3DGradPrim_):
     def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners):
-        converted_interpolation_mode = str_to_enum(interpolation_mode)
-        converted_padding_mode = str_to_enum(padding_mode)
-        return _convert_stub(super().__call__(grad, input_x, grid, interpolation_mode, padding_mode, align_corners))
+        converted_interpolation_mode = str_to_enum('grid_sampler_3d_grad', 'interpolation_mode', interpolation_mode)
+        converted_padding_mode = str_to_enum('grid_sampler_3d_grad', 'padding_mode', padding_mode)
+        return _convert_stub(super().__call__(grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners))


 grid_sampler_3d_grad_impl = _PyboostGridSampler3DGradPrim()
@@ -229,14 +259,41 @@ grid_sampler_3d_grad_impl = _PyboostGridSampler3DGradPrim()

 class _PyboostGridSampler3DPrim(GridSampler3DPrim_):
     def __call__(self, input_x, grid, interpolation_mode, padding_mode, align_corners):
-        converted_interpolation_mode = str_to_enum(interpolation_mode)
-        converted_padding_mode = str_to_enum(padding_mode)
-        return _convert_stub(super().__call__(input_x, grid, interpolation_mode, padding_mode, align_corners))
+        converted_interpolation_mode = str_to_enum('grid_sampler_3d', 'interpolation_mode', interpolation_mode)
+        converted_padding_mode = str_to_enum('grid_sampler_3d', 'padding_mode', padding_mode)
+        return _convert_stub(super().__call__(input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners))


 grid_sampler_3d_impl = _PyboostGridSampler3DPrim()


+class _PyboostHShrinkGradPrim(HShrinkGradPrim_):
+    def __call__(self, gradients, features, lambd):
+
+        return _convert_stub(super().__call__(gradients, features, lambd))
+
+
+hshrink_grad_impl = _PyboostHShrinkGradPrim()
+
+
+class _PyboostHShrinkPrim(HShrinkPrim_):
+    def __call__(self, input, lambd):
+
+        return _convert_stub(super().__call__(input, lambd))
+
+
+hshrink_impl = _PyboostHShrinkPrim()
+
+
+class _PyboostIncreFlashAttentionPrim(IncreFlashAttentionPrim_):
+    def __call__(self, query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, input_layout, scale_value, num_key_value_heads, block_size, inner_precise):
+        converted_input_layout = str_to_enum('incre_flash_attention', 'input_layout', input_layout)
+        return _convert_stub(super().__call__(query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, converted_input_layout, scale_value, num_key_value_heads, block_size, inner_precise))
+
+
+incre_flash_attention_impl = _PyboostIncreFlashAttentionPrim()
+
+
 class _PyboostIsClosePrim(IsClosePrim_):
     def __call__(self, input, other, rtol, atol, equal_nan):

@@ -246,6 +303,24 @@ class _PyboostIsClosePrim(IsClosePrim_):
 isclose_impl = _PyboostIsClosePrim()


+class _PyboostLogSoftmaxGradPrim(LogSoftmaxGradPrim_):
+    def __call__(self, logits, grad, axis):
+
+        return _convert_stub(super().__call__(logits, grad, axis))
+
+
+log_softmax_grad_impl = _PyboostLogSoftmaxGradPrim()
+
+
+class _PyboostLogSoftmaxPrim(LogSoftmaxPrim_):
+    def __call__(self, logits, axis):
+
+        return _convert_stub(super().__call__(logits, axis))
+
+
+log_softmax_impl = _PyboostLogSoftmaxPrim()
+
+
 class _PyboostMatMulPrim(MatMulPrim_):
     def __call__(self, input, mat2, transpose_a, transpose_b):

@@ -257,11 +332,11 @@ matmul_impl = _PyboostMatMulPrim()

 class _PyboostMaxPoolGradWithIndicesPrim(MaxPoolGradWithIndicesPrim_):
     def __call__(self, x, grad, argmax, kernel_size, strides, pads, dilation, ceil_mode, argmax_type):
-        converted_kernel_size = to_kernel_size(kernel_size)
-        converted_strides = to_strides(strides)
-        converted_pads = to_output_padding(pads)
-        converted_dilation = to_dilations(dilation)
-        return _convert_stub(super().__call__(x, grad, argmax, kernel_size, strides, pads, dilation, ceil_mode, argmax_type))
+        converted_kernel_size = to_kernel_size('max_pool_grad_with_indices', 'kernel_size', kernel_size)
+        converted_strides = to_strides('max_pool_grad_with_indices', 'strides', strides)
+        converted_pads = to_output_padding('max_pool_grad_with_indices', 'pads', pads)
+        converted_dilation = to_dilations('max_pool_grad_with_indices', 'dilation', dilation)
+        return _convert_stub(super().__call__(x, grad, argmax, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type))


 max_pool_grad_with_indices_impl = _PyboostMaxPoolGradWithIndicesPrim()
@@ -269,11 +344,11 @@ max_pool_grad_with_indices_impl = _PyboostMaxPoolGradWithIndicesPrim()

 class _PyboostMaxPoolGradWithMaskPrim(MaxPoolGradWithMaskPrim_):
     def __call__(self, x, grad, mask, kernel_size, strides, pads, dilation, ceil_mode, argmax_type):
-        converted_kernel_size = to_kernel_size(kernel_size)
-        converted_strides = to_strides(strides)
-        converted_pads = to_output_padding(pads)
-        converted_dilation = to_dilations(dilation)
-        return _convert_stub(super().__call__(x, grad, mask, kernel_size, strides, pads, dilation, ceil_mode, argmax_type))
+        converted_kernel_size = to_kernel_size('max_pool_grad_with_mask', 'kernel_size', kernel_size)
+        converted_strides = to_strides('max_pool_grad_with_mask', 'strides', strides)
+        converted_pads = to_output_padding('max_pool_grad_with_mask', 'pads', pads)
+        converted_dilation = to_dilations('max_pool_grad_with_mask', 'dilation', dilation)
+        return _convert_stub(super().__call__(x, grad, mask, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type))


 max_pool_grad_with_mask_impl = _PyboostMaxPoolGradWithMaskPrim()
@@ -281,11 +356,11 @@ max_pool_grad_with_mask_impl = _PyboostMaxPoolGradWithMaskPrim()

 class _PyboostMaxPoolWithIndicesPrim(MaxPoolWithIndicesPrim_):
     def __call__(self, x, kernel_size, strides, pads, dilation, ceil_mode, argmax_type):
-        converted_kernel_size = to_kernel_size(kernel_size)
-        converted_strides = to_strides(strides)
-        converted_pads = to_output_padding(pads)
-        converted_dilation = to_dilations(dilation)
-        return _convert_stub(super().__call__(x, kernel_size, strides, pads, dilation, ceil_mode, argmax_type))
+        converted_kernel_size = to_kernel_size('max_pool_with_indices', 'kernel_size', kernel_size)
+        converted_strides = to_strides('max_pool_with_indices', 'strides', strides)
+        converted_pads = to_output_padding('max_pool_with_indices', 'pads', pads)
+        converted_dilation = to_dilations('max_pool_with_indices', 'dilation', dilation)
+        return _convert_stub(super().__call__(x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type))


 max_pool_with_indices_impl = _PyboostMaxPoolWithIndicesPrim()
@@ -293,16 +368,25 @@ max_pool_with_indices_impl = _PyboostMaxPoolWithIndicesPrim()

 class _PyboostMaxPoolWithMaskPrim(MaxPoolWithMaskPrim_):
     def __call__(self, x, kernel_size, strides, pads, dilation, ceil_mode, argmax_type):
-        converted_kernel_size = to_kernel_size(kernel_size)
-        converted_strides = to_strides(strides)
-        converted_pads = to_output_padding(pads)
-        converted_dilation = to_dilations(dilation)
-        return _convert_stub(super().__call__(x, kernel_size, strides, pads, dilation, ceil_mode, argmax_type))
+        converted_kernel_size = to_kernel_size('max_pool_with_mask', 'kernel_size', kernel_size)
+        converted_strides = to_strides('max_pool_with_mask', 'strides', strides)
+        converted_pads = to_output_padding('max_pool_with_mask', 'pads', pads)
+        converted_dilation = to_dilations('max_pool_with_mask', 'dilation', dilation)
+        return _convert_stub(super().__call__(x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type))


 max_pool_with_mask_impl = _PyboostMaxPoolWithMaskPrim()


+class _PyboostNanToNumPrim(NanToNumPrim_):
+    def __call__(self, input, nan, posinf, neginf):
+
+        return _convert_stub(super().__call__(input, nan, posinf, neginf))
+
+
+nan_to_num_impl = _PyboostNanToNumPrim()
+
+
 class _PyboostOneHotExtPrim(OneHotExtPrim_):
     def __call__(self, tensor, num_classes, on_value, off_value, axis):

@@ -348,6 +432,15 @@ class _PyboostRmsNormPrim(RmsNormPrim_):
 rms_norm_impl = _PyboostRmsNormPrim()


+class _PyboostRollPrim(RollPrim_):
+    def __call__(self, input, shift, axis):
+
+        return _convert_stub(super().__call__(input, shift, axis))
+
+
+roll_impl = _PyboostRollPrim()
+
+
 class _PyboostSearchSortedPrim(SearchSortedPrim_):
     def __call__(self, sorted_sequence, values, sorter, dtype, right):

@@ -366,6 +459,24 @@ class _PyboostSoftmaxPrim(SoftmaxPrim_):
 softmax_impl = _PyboostSoftmaxPrim()


+class _PyboostSoftShrinkGradPrim(SoftShrinkGradPrim_):
+    def __call__(self, input_grad, input_x, lambd):
+
+        return _convert_stub(super().__call__(input_grad, input_x, lambd))
+
+
+softshrink_grad_impl = _PyboostSoftShrinkGradPrim()
+
+
+class _PyboostSoftShrinkPrim(SoftShrinkPrim_):
+    def __call__(self, input, lambd):
+
+        return _convert_stub(super().__call__(input, lambd))
+
+
+softshrink_impl = _PyboostSoftShrinkPrim()
+
+
 class _PyboostStackExtPrim(StackExtPrim_):
     def __call__(self, tensors, dim):

@@ -375,6 +486,15 @@ class _PyboostStackExtPrim(StackExtPrim_):
 stack_ext_impl = _PyboostStackExtPrim()


+class _PyboostTrilExtPrim(TrilExtPrim_):
+    def __call__(self, input, diagonal):
+
+        return _convert_stub(super().__call__(input, diagonal))
+
+
+tril_ext_impl = _PyboostTrilExtPrim()
+
+
 class _PyboostTriuPrim(TriuPrim_):
     def __call__(self, input, diagonal):

@@ -412,9 +532,9 @@ grouped_matmul_impl = _PyboostGroupedMatmulPrim()


 class _PyboostQuantBatchMatmulPrim(QuantBatchMatmulPrim_):
-    def __call__(self, x1, x2, scale, offset, bias, transpose_x1, transpose_x2, dtype):
+    def __call__(self, x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype):

-        return _convert_stub(super().__call__(x1, x2, scale, offset, bias, transpose_x1, transpose_x2, dtype))
+        return _convert_stub(super().__call__(x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype))


 quant_batch_matmul_impl = _PyboostQuantBatchMatmulPrim()
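
A note on the pattern running through the pyboost hunks above: in 2.3.0 each wrapper computed a converted_* value and then forwarded the raw argument to the underlying C++ primitive, leaving the conversion dead; 2.4.0 forwards the converted value and also passes the operator and argument names into the converters (str_to_enum, to_strides, to_2d_paddings, to_dilations, to_output_padding, to_kernel_size), presumably so validation errors can name the offending op and argument. A minimal sketch of the new converter convention, with a hypothetical enum table (MindSpore's real implementation lives in its auto-generated helpers and may differ):

    # Hypothetical sketch only; `_REDUCTION` and the error wording are
    # assumptions, not MindSpore's actual code.
    _REDUCTION = {'none': 0, 'mean': 1, 'sum': 2}

    def str_to_enum(op_name, arg_name, value):
        """Map a string argument to its enum value, naming op/arg on failure."""
        if value not in _REDUCTION:
            raise ValueError(f"For '{op_name}', '{arg_name}' must be one of "
                             f"{sorted(_REDUCTION)}, but got '{value}'.")
        return _REDUCTION[value]

    # 2.3.0 computed the conversion but forwarded the raw string:
    #     converted_reduction = str_to_enum(reduction)
    #     super().__call__(..., reduction)
    # 2.4.0 forwards what it converted:
    #     converted_reduction = str_to_enum('binary_cross_entropy', 'reduction', reduction)
    #     super().__call__(..., converted_reduction)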
--- a/mindspore/ops/composite/base.py
+++ b/mindspore/ops/composite/base.py
@@ -30,7 +30,7 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, \
     SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
     ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
     ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_, DictSetItem_, \
-    HandleBoolTensor_, PreSetitemByTuple_, StarredGetItem_,\
+    HandleBoolTensor_, PreSetitemByTuple_, StarredGetItem_, \
     StarredUnpack_, StarredUnpackMerge_, IterConverter_, HasNext_, Next_, MSContext
 from mindspore.common import dtype as mstype
 from mindspore.common.api import jit, _pynative_executor, _wrap_func
@@ -346,9 +346,11 @@ class GradOperation(GradOperation_):
         self.grad_position = (0,)

     def __call__(self, fn, weights=None):
-        weights_id = _get_grad_weights_id(weights)
-        if self.grad_fn is not None and self.fn == fn and self.weights_id == weights_id:
-            return self.grad_fn
+        weights_id = ''
+        if context.get_context("mode") == context.GRAPH_MODE:
+            weights_id = _get_grad_weights_id(weights)
+            if self.grad_fn is not None and self.fn == fn and self.weights_id == weights_id:
+                return self.grad_fn
         grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
         # If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
         # If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
@@ -374,8 +376,8 @@ class GradOperation(GradOperation_):

             @_wrap_func
             def after_grad(*args, **kwargs):
-                self._pynative_forward_run(fn, grad_, weights, args, kwargs)
-                out = _pynative_executor.grad(fn, grad_, weights, self.grad_position, *args, **kwargs)
+                run_args = self._pynative_forward_run(fn, grad_, weights, *args, **kwargs)
+                out = _pynative_executor.grad(fn, grad_, weights, self.grad_position, *run_args)
                 out = _grads_divided_by_device_num_if_recomputation(out)
                 return out
         else:
@@ -396,26 +398,39 @@ class GradOperation(GradOperation_):
         self.weights_id = weights_id
         return self.grad_fn

-    def _pynative_forward_run(self, fn, grad, weights, args, kwargs):
-        """ Pynative forward run to build grad graph. """
-        new_kwargs = kwargs
+    def _pynative_forward_run(self, fn, grad, weights, *args, **kwargs):
+        """ PyNative forward run to build grad graph. """
+        sens = None
         if self.sens_param:
-            if 'sens' not in kwargs.keys():
-                args = args[:-1]
+            if 'sens' in kwargs.keys():
+                sens = kwargs.pop('sens')
             else:
-                new_kwargs = kwargs.copy()
-                new_kwargs.pop('sens')
+                # default use args last elem as sens
+                sens = args[-1]
+                args = args[:-1]
+        run_args = args
+        if kwargs:
+            run_args = args + tuple(kwargs.values())
+
+        # check run exclude sens
         if isinstance(fn, (FunctionType, MethodType)):
-            if not _pynative_executor.check_run(grad, fn, weights, None, *args, **new_kwargs):
+            if not _pynative_executor.check_run(grad, fn, weights, None, *run_args):
                 _pynative_executor.set_grad_flag(True)
-                _pynative_executor.new_graph(fn, *args, **new_kwargs)
-                output = fn(*args, **new_kwargs)
-                _pynative_executor.end_graph(fn, output, *args, **new_kwargs)
+                _pynative_executor.new_graph(fn, *args, **kwargs)
+                output = fn(*args, **kwargs)
+                _pynative_executor.end_graph(fn, output, *args, **kwargs)
         else:
-            # Check if fn have run already
-            if not _pynative_executor.check_run(grad, fn, weights, None, *args, **new_kwargs):
-                _pynative_executor.set_grad_flag(True)
-                fn(*args, **new_kwargs)
+            # Check if fn has run already
+            if not _pynative_executor.check_run(grad, fn, weights, None, *run_args):
+                requires_grad = fn.requires_grad
+                fn.requires_grad = True
+                fn(*args, **kwargs)
+                fn.requires_grad = requires_grad
+
+        # If it has sens, keep sens as the last element
+        if sens is not None:
+            run_args += (sens,) if sens is not isinstance(run_args, tuple) else sens
+        return run_args


 class _TaylorOperation(TaylorOperation_):
@@ -552,13 +567,15 @@ class _Grad(GradOperation_):
         self.weights_id = None

     def __call__(self, fn, weights=None, grad_position=0):
-        weights_id = _get_grad_weights_id(weights)
-        if self.grad_fn is not None and self.fn == fn and self.grad_position == grad_position and \
-                self.weights_id == weights_id:
-            return self.grad_fn
+        weights_id = ''
+        if context.get_context("mode") == context.GRAPH_MODE:
+            weights_id = _get_grad_weights_id(weights)
+            if self.grad_fn is not None and self.fn == fn and self.grad_position == grad_position and \
+                    self.weights_id == weights_id:
+                return self.grad_fn

-        def aux_fn(*args):
-            outputs = fn(*args)
+        def aux_fn(*args, **kwargs):
+            outputs = fn(*args, **kwargs)
             if not isinstance(outputs, tuple) or len(outputs) < 2:
                 raise ValueError("When has_aux is True, origin fn requires more than one outputs.")
             res = (outputs[0],)
@@ -597,8 +614,8 @@ class _Grad(GradOperation_):

             @_wrap_func
             def after_grad(*args, **kwargs):
-                res = self._pynative_forward_run(fn, grad_, weights, args, kwargs)
-                out = _pynative_executor.grad(fn, grad_, weights, grad_position, *args, **kwargs)
+                run_args, res = self._pynative_forward_run(fn, grad_, weights, *args, **kwargs)
+                out = _pynative_executor.grad(fn, grad_, weights, grad_position, *run_args)
                 out = _grads_divided_by_device_num_if_recomputation(out)
                 if self.return_ids and out:
                     out = _combine_with_ids(grad_position, weights, out)
@@ -633,32 +650,49 @@ class _Grad(GradOperation_):
         self.weights_id = weights_id
         return self.grad_fn

-    def _pynative_forward_run(self, fn, grad, weights, args, kwargs):
-        """ Pynative forward runs to build grad graph. """
-        new_kwargs = kwargs
-        outputs = ()
+    def _pynative_forward_run(self, fn, grad, weights, *args, **kwargs):
+        """ PyNative forward runs to build grad graph. """
+        sens = None
         if self.sens_param:
             if 'sens' in kwargs.keys():
-                new_kwargs = kwargs.copy()
-                new_kwargs.pop('sens')
+                sens = kwargs.pop('sens')
             else:
+                # default use args last elem as sens
+                sens = args[-1]
                 args = args[:-1]
+        run_args = args
+        if kwargs:
+            run_args = args + tuple(kwargs.values())
+
+        # check run exclude sens
+        outputs = ()
+        run_forward = False
         if isinstance(fn, (FunctionType, MethodType)):
-            if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *args, **new_kwargs):
+            if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *run_args):
                 _pynative_executor.set_grad_flag(True)
-                _pynative_executor.new_graph(fn, *args, **new_kwargs)
-                outputs = fn(*args, **new_kwargs)
-                _pynative_executor.end_graph(fn, outputs, *args, **new_kwargs)
-                return outputs
+                _pynative_executor.new_graph(fn, *args, **kwargs)
+                outputs = fn(*args, **kwargs)
+                _pynative_executor.end_graph(fn, outputs, *args, **kwargs)
+                run_forward = True
         else:
             # Check if fn has run already.
-            if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *args, **new_kwargs):
-                _pynative_executor.set_grad_flag(True)
-                outputs = fn(*args, **new_kwargs)
-                return outputs
+            if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *run_args):
+                requires_grad = fn.requires_grad
+                fn.requires_grad = True
+                outputs = fn(*args, **kwargs)
+                fn.requires_grad = requires_grad
+                run_forward = True
+        # If it has sens, keep sens as the last element
+        if sens is not None:
+            run_args += (sens,) if sens is not isinstance(run_args, tuple) else sens
+
+        # Normal run grad
+        if run_forward:
+            return run_args, outputs
+
         if (self.get_value or self.has_aux) and not outputs:
-            outputs = fn(*args, **new_kwargs)
-        return outputs
+            outputs = fn(*args, **kwargs)
+        return run_args, outputs


 class _Vmap(VmapOperation_):
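
For orientation on the `sens` handling in the two `_pynative_forward_run` rewrites above: the sensitivity value may now arrive either as the `'sens'` keyword or as the last positional argument; it is stripped before the `check_run` cache lookup, remaining kwargs are flattened into `run_args`, and `sens` is re-appended as the last element before `_pynative_executor.grad` runs. A minimal PyNative-mode usage sketch of the two call styles (illustrative, not taken from the diff):

    import mindspore as ms
    from mindspore import Tensor, ops

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def net(x, y):
        return x * y

    grad_op = ops.GradOperation(sens_param=True)
    x, y, sens = Tensor(2.0), Tensor(3.0), Tensor(1.0)

    # Both styles are routed through _pynative_forward_run:
    g1 = grad_op(net)(x, y, sens)       # sens as the trailing positional arg
    g2 = grad_op(net)(x, y, sens=sens)  # sens popped from kwargs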
@@ -806,10 +840,12 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):

 class HyperMap(HyperMap_):
     """
-    Hypermap will apply the set operation to input sequences.
+    HyperMap will apply the set operation to input sequences.

     Apply the operations to every element of the sequence or nested sequence. Different
-    from `mindspore.ops.Map`, the `HyperMap` supports to apply on nested structure.
+    from `mindspore.ops.Map`, the `HyperMap` supports to apply on nested structure. The
+    `HyperMap` also supports dynamic sequences as input, but it does not extend this
+    support to nested dynamic sequences.

     Args:
         ops (Union[MultitypeFuncGraph, None], optional): `ops` is the operation to apply. If `ops` is `None`,
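
The docstring change above records that `HyperMap`, unlike `Map`, recurses into nested sequences, and that dynamic sequences are supported only at the top level. A short usage sketch in the spirit of the official examples:

    from mindspore import Tensor
    from mindspore.ops import MultitypeFuncGraph, HyperMap

    square = MultitypeFuncGraph('square')

    @square.register("Tensor")
    def _square_tensor(x):
        return x * x

    # HyperMap applies `square` to every leaf of the nested tuple.
    nest = ((Tensor(1.0), Tensor(2.0)), (Tensor(3.0),))
    out = HyperMap()(square, nest)  # ((1.0, 4.0), (9.0,))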
@@ -959,6 +995,7 @@ class _ListAppend(ListAppend_):
     Args:
         name (str): The name of the metafuncgraph object.
     """
+
     # `__init__` method removed entirely
     def __call__(self, *args):
         pass
--- a/mindspore/ops/composite/multitype_ops/_compile_utils.py
+++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py
@@ -483,6 +483,7 @@ def format_index_tensor(index, arg):
         index[format_idx] = F.select(index_tensor < 0, index_tensor + format_dim, index_tensor)
         return index
     index = Tensor(index)
+    format_dims = Tensor(format_dims)
     return F.select(index < 0, index + format_dims, index)


--- a/mindspore/ops/composite/multitype_ops/not_in_impl.py
+++ b/mindspore/ops/composite/multitype_ops/not_in_impl.py
@@ -41,7 +41,7 @@ def _number_not_in_tuple(x, y):
     Returns:
         bool, if x not in y return true, x in y return false.
     """
-    if F.is_sequence_shape_unknown(y) or not F.isconstant(x):
+    if F.is_sequence_value_unknown(y) or not F.isconstant(x):
         return not InSequence()(x, y)
     return not const_utils.scalar_in_sequence(x, y)

@@ -58,7 +58,7 @@ def _number_not_in_list(x, y):
     Returns:
         bool, if x not in y return true, x in y return false.
     """
-    if F.is_sequence_shape_unknown(y) or not F.isconstant(x):
+    if F.is_sequence_value_unknown(y) or not F.isconstant(x):
         return not InSequence()(x, y)
     return not const_utils.scalar_in_sequence(x, y)

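
The `not in` fix above swaps the guard from `F.is_sequence_shape_unknown` to `F.is_sequence_value_unknown`: a graph-mode sequence can have a known shape while its element values are unknown, and only a fully constant membership test can be folded at compile time; everything else falls back to the `InSequence` op. A small, hedged sketch of the code path this guards (behavior inferred from the hunk):

    import mindspore as ms

    @ms.jit
    def not_in_check(n):
        # Dispatched to _number_not_in_tuple in graph mode: constant operands
        # fold at compile time; otherwise the InSequence op runs.
        return n not in (1, 2, 3)

    print(not_in_check(4))  # True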