mindspore 2.3.0-cp310-cp310-win_amd64.whl → 2.4.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (308)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +3 -1
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +50 -9
  9. mindspore/_extends/parse/compile_config.py +41 -0
  10. mindspore/_extends/parse/parser.py +9 -7
  11. mindspore/_extends/parse/standard_method.py +52 -14
  12. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  13. mindspore/amp.py +24 -10
  14. mindspore/atlprov.dll +0 -0
  15. mindspore/avcodec-59.dll +0 -0
  16. mindspore/avdevice-59.dll +0 -0
  17. mindspore/avfilter-8.dll +0 -0
  18. mindspore/avformat-59.dll +0 -0
  19. mindspore/avutil-57.dll +0 -0
  20. mindspore/c1.dll +0 -0
  21. mindspore/c1xx.dll +0 -0
  22. mindspore/c2.dll +0 -0
  23. mindspore/common/__init__.py +6 -4
  24. mindspore/common/_pijit_context.py +190 -0
  25. mindspore/common/_register_for_tensor.py +2 -1
  26. mindspore/common/_tensor_overload.py +139 -0
  27. mindspore/common/api.py +102 -87
  28. mindspore/common/dump.py +5 -6
  29. mindspore/common/generator.py +1 -7
  30. mindspore/common/hook_handle.py +14 -26
  31. mindspore/common/mindir_util.py +2 -2
  32. mindspore/common/parameter.py +46 -13
  33. mindspore/common/recompute.py +39 -9
  34. mindspore/common/sparse_tensor.py +7 -3
  35. mindspore/common/tensor.py +209 -29
  36. mindspore/communication/__init__.py +1 -1
  37. mindspore/communication/_comm_helper.py +38 -3
  38. mindspore/communication/comm_func.py +310 -55
  39. mindspore/communication/management.py +14 -14
  40. mindspore/context.py +123 -22
  41. mindspore/dataset/__init__.py +1 -1
  42. mindspore/dataset/audio/__init__.py +1 -1
  43. mindspore/dataset/core/config.py +7 -0
  44. mindspore/dataset/core/validator_helpers.py +7 -0
  45. mindspore/dataset/engine/cache_client.py +1 -1
  46. mindspore/dataset/engine/datasets.py +72 -44
  47. mindspore/dataset/engine/datasets_audio.py +7 -7
  48. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  49. mindspore/dataset/engine/datasets_text.py +20 -20
  50. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  51. mindspore/dataset/engine/datasets_vision.py +33 -33
  52. mindspore/dataset/engine/iterators.py +29 -0
  53. mindspore/dataset/engine/obs/util.py +7 -0
  54. mindspore/dataset/engine/queue.py +114 -60
  55. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  56. mindspore/dataset/engine/validators.py +34 -14
  57. mindspore/dataset/text/__init__.py +1 -4
  58. mindspore/dataset/transforms/__init__.py +0 -3
  59. mindspore/dataset/utils/line_reader.py +2 -0
  60. mindspore/dataset/vision/__init__.py +1 -4
  61. mindspore/dataset/vision/utils.py +1 -1
  62. mindspore/dataset/vision/validators.py +2 -1
  63. mindspore/dnnl.dll +0 -0
  64. mindspore/dpcmi.dll +0 -0
  65. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  66. mindspore/experimental/es/embedding_service.py +883 -0
  67. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  68. mindspore/experimental/llm_boost/__init__.py +21 -0
  69. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  70. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  71. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  72. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  73. mindspore/experimental/llm_boost/register.py +129 -0
  74. mindspore/experimental/llm_boost/utils.py +31 -0
  75. mindspore/experimental/optim/adamw.py +85 -0
  76. mindspore/experimental/optim/optimizer.py +3 -0
  77. mindspore/hal/__init__.py +3 -3
  78. mindspore/hal/contiguous_tensors_handle.py +175 -0
  79. mindspore/hal/stream.py +18 -0
  80. mindspore/include/api/model_group.h +13 -1
  81. mindspore/include/api/types.h +10 -10
  82. mindspore/include/dataset/config.h +2 -2
  83. mindspore/include/dataset/constants.h +2 -2
  84. mindspore/include/dataset/execute.h +2 -2
  85. mindspore/include/dataset/vision.h +4 -0
  86. mindspore/jpeg62.dll +0 -0
  87. mindspore/log.py +1 -1
  88. mindspore/mindrecord/filewriter.py +68 -51
  89. mindspore/mindspore_backend.dll +0 -0
  90. mindspore/mindspore_common.dll +0 -0
  91. mindspore/mindspore_core.dll +0 -0
  92. mindspore/mindspore_glog.dll +0 -0
  93. mindspore/mindspore_np_dtype.dll +0 -0
  94. mindspore/mindspore_ops.dll +0 -0
  95. mindspore/mint/__init__.py +495 -46
  96. mindspore/mint/distributed/__init__.py +31 -0
  97. mindspore/mint/distributed/distributed.py +254 -0
  98. mindspore/mint/nn/__init__.py +266 -21
  99. mindspore/mint/nn/functional.py +125 -19
  100. mindspore/mint/nn/layer/__init__.py +39 -0
  101. mindspore/mint/nn/layer/activation.py +133 -0
  102. mindspore/mint/nn/layer/normalization.py +477 -0
  103. mindspore/mint/nn/layer/pooling.py +110 -0
  104. mindspore/mint/optim/adamw.py +28 -7
  105. mindspore/mint/special/__init__.py +63 -0
  106. mindspore/msobj140.dll +0 -0
  107. mindspore/mspdb140.dll +0 -0
  108. mindspore/mspdbcore.dll +0 -0
  109. mindspore/mspdbst.dll +0 -0
  110. mindspore/mspft140.dll +0 -0
  111. mindspore/msvcdis140.dll +0 -0
  112. mindspore/msvcp140_1.dll +0 -0
  113. mindspore/msvcp140_2.dll +0 -0
  114. mindspore/msvcp140_atomic_wait.dll +0 -0
  115. mindspore/msvcp140_codecvt_ids.dll +0 -0
  116. mindspore/multiprocessing/__init__.py +2 -1
  117. mindspore/nn/__init__.py +0 -1
  118. mindspore/nn/cell.py +275 -93
  119. mindspore/nn/layer/activation.py +211 -44
  120. mindspore/nn/layer/basic.py +113 -3
  121. mindspore/nn/layer/embedding.py +120 -2
  122. mindspore/nn/layer/normalization.py +101 -5
  123. mindspore/nn/layer/padding.py +34 -48
  124. mindspore/nn/layer/pooling.py +161 -7
  125. mindspore/nn/layer/transformer.py +3 -3
  126. mindspore/nn/loss/__init__.py +2 -2
  127. mindspore/nn/loss/loss.py +84 -6
  128. mindspore/nn/optim/__init__.py +2 -1
  129. mindspore/nn/optim/adadelta.py +1 -1
  130. mindspore/nn/optim/adam.py +1 -1
  131. mindspore/nn/optim/lamb.py +1 -1
  132. mindspore/nn/optim/tft_wrapper.py +127 -0
  133. mindspore/nn/wrap/cell_wrapper.py +12 -23
  134. mindspore/nn/wrap/grad_reducer.py +5 -5
  135. mindspore/nn/wrap/loss_scale.py +17 -3
  136. mindspore/numpy/__init__.py +1 -1
  137. mindspore/numpy/array_creations.py +65 -68
  138. mindspore/numpy/array_ops.py +64 -60
  139. mindspore/numpy/fft.py +610 -75
  140. mindspore/numpy/logic_ops.py +11 -10
  141. mindspore/numpy/math_ops.py +85 -84
  142. mindspore/numpy/utils_const.py +4 -4
  143. mindspore/opencv_core452.dll +0 -0
  144. mindspore/opencv_imgcodecs452.dll +0 -0
  145. mindspore/opencv_imgproc452.dll +0 -0
  146. mindspore/ops/__init__.py +6 -4
  147. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  148. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  149. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  150. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  151. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  152. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  153. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  154. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  155. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  156. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  157. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  158. mindspore/ops/composite/base.py +85 -48
  159. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  160. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  161. mindspore/ops/function/__init__.py +22 -0
  162. mindspore/ops/function/array_func.py +490 -153
  163. mindspore/ops/function/debug_func.py +113 -1
  164. mindspore/ops/function/fft_func.py +15 -2
  165. mindspore/ops/function/grad/grad_func.py +3 -2
  166. mindspore/ops/function/math_func.py +558 -207
  167. mindspore/ops/function/nn_func.py +817 -383
  168. mindspore/ops/function/other_func.py +3 -2
  169. mindspore/ops/function/random_func.py +184 -8
  170. mindspore/ops/function/reshard_func.py +13 -11
  171. mindspore/ops/function/sparse_unary_func.py +1 -1
  172. mindspore/ops/function/vmap_func.py +3 -2
  173. mindspore/ops/functional.py +24 -14
  174. mindspore/ops/op_info_register.py +3 -3
  175. mindspore/ops/operations/__init__.py +6 -1
  176. mindspore/ops/operations/_grad_ops.py +2 -76
  177. mindspore/ops/operations/_infer_ops.py +1 -1
  178. mindspore/ops/operations/_inner_ops.py +71 -94
  179. mindspore/ops/operations/array_ops.py +12 -146
  180. mindspore/ops/operations/comm_ops.py +42 -53
  181. mindspore/ops/operations/custom_ops.py +83 -19
  182. mindspore/ops/operations/debug_ops.py +42 -10
  183. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  184. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  185. mindspore/ops/operations/math_ops.py +12 -223
  186. mindspore/ops/operations/nn_ops.py +20 -114
  187. mindspore/ops/operations/other_ops.py +7 -4
  188. mindspore/ops/operations/random_ops.py +46 -1
  189. mindspore/ops/primitive.py +18 -6
  190. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  191. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  192. mindspore/ops_generate/gen_constants.py +36 -0
  193. mindspore/ops_generate/gen_ops.py +67 -52
  194. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  195. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  196. mindspore/ops_generate/op_proto.py +10 -3
  197. mindspore/ops_generate/pyboost_utils.py +14 -1
  198. mindspore/ops_generate/template.py +43 -21
  199. mindspore/parallel/__init__.py +3 -1
  200. mindspore/parallel/_auto_parallel_context.py +28 -8
  201. mindspore/parallel/_cell_wrapper.py +83 -0
  202. mindspore/parallel/_parallel_serialization.py +47 -19
  203. mindspore/parallel/_tensor.py +81 -11
  204. mindspore/parallel/_utils.py +13 -1
  205. mindspore/parallel/algo_parameter_config.py +5 -5
  206. mindspore/parallel/checkpoint_transform.py +46 -39
  207. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  208. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  209. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  210. mindspore/parallel/parameter_broadcast.py +3 -4
  211. mindspore/parallel/shard.py +162 -31
  212. mindspore/parallel/transform_safetensors.py +993 -0
  213. mindspore/pgodb140.dll +0 -0
  214. mindspore/pgort140.dll +0 -0
  215. mindspore/profiler/__init__.py +2 -1
  216. mindspore/profiler/common/constant.py +29 -0
  217. mindspore/profiler/common/registry.py +47 -0
  218. mindspore/profiler/common/util.py +28 -0
  219. mindspore/profiler/dynamic_profiler.py +694 -0
  220. mindspore/profiler/envprofiling.py +17 -19
  221. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  222. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  223. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  227. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  228. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  229. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  230. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  231. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  232. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  233. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  234. mindspore/profiler/parser/framework_parser.py +1 -391
  235. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  236. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  237. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  238. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  239. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  240. mindspore/profiler/parser/profiler_info.py +78 -6
  241. mindspore/profiler/profiler.py +153 -0
  242. mindspore/profiler/profiling.py +280 -412
  243. mindspore/rewrite/__init__.py +1 -2
  244. mindspore/rewrite/common/namespace.py +4 -4
  245. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  246. mindspore/run_check/_check_version.py +36 -103
  247. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  248. mindspore/swresample-4.dll +0 -0
  249. mindspore/swscale-6.dll +0 -0
  250. mindspore/tbbmalloc.dll +0 -0
  251. mindspore/tinyxml2.dll +0 -0
  252. mindspore/train/__init__.py +4 -3
  253. mindspore/train/_utils.py +28 -2
  254. mindspore/train/amp.py +171 -53
  255. mindspore/train/callback/__init__.py +2 -2
  256. mindspore/train/callback/_callback.py +4 -4
  257. mindspore/train/callback/_checkpoint.py +85 -22
  258. mindspore/train/callback/_cluster_monitor.py +1 -1
  259. mindspore/train/callback/_flops_collector.py +1 -0
  260. mindspore/train/callback/_loss_monitor.py +3 -3
  261. mindspore/train/callback/_on_request_exit.py +134 -31
  262. mindspore/train/callback/_summary_collector.py +5 -5
  263. mindspore/train/callback/_tft_register.py +352 -0
  264. mindspore/train/dataset_helper.py +7 -3
  265. mindspore/train/metrics/metric.py +3 -3
  266. mindspore/train/metrics/roc.py +4 -4
  267. mindspore/train/mind_ir_pb2.py +44 -39
  268. mindspore/train/model.py +134 -58
  269. mindspore/train/serialization.py +336 -112
  270. mindspore/turbojpeg.dll +0 -0
  271. mindspore/utils/__init__.py +21 -0
  272. mindspore/utils/utils.py +60 -0
  273. mindspore/vcmeta.dll +0 -0
  274. mindspore/vcruntime140.dll +0 -0
  275. mindspore/vcruntime140_1.dll +0 -0
  276. mindspore/version.py +1 -1
  277. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  278. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
  279. mindspore/include/c_api/ms/abstract.h +0 -67
  280. mindspore/include/c_api/ms/attribute.h +0 -197
  281. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  282. mindspore/include/c_api/ms/base/macros.h +0 -32
  283. mindspore/include/c_api/ms/base/status.h +0 -33
  284. mindspore/include/c_api/ms/base/types.h +0 -283
  285. mindspore/include/c_api/ms/context.h +0 -102
  286. mindspore/include/c_api/ms/graph.h +0 -160
  287. mindspore/include/c_api/ms/node.h +0 -606
  288. mindspore/include/c_api/ms/tensor.h +0 -161
  289. mindspore/include/c_api/ms/value.h +0 -84
  290. mindspore/mindspore_shared_lib.dll +0 -0
  291. mindspore/nn/extend/basic.py +0 -140
  292. mindspore/nn/extend/embedding.py +0 -143
  293. mindspore/nn/extend/layer/normalization.py +0 -109
  294. mindspore/nn/extend/pooling.py +0 -117
  295. mindspore/nn/layer/embedding_service.py +0 -531
  296. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  297. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  298. mindspore/ops/extend/__init__.py +0 -53
  299. mindspore/ops/extend/array_func.py +0 -218
  300. mindspore/ops/extend/math_func.py +0 -76
  301. mindspore/ops/extend/nn_func.py +0 -308
  302. mindspore/ops/silent_check.py +0 -162
  303. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  304. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  305. mindspore/train/callback/_mindio_ttp.py +0 -443
  306. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  307. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  308. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/ops/__init__.py CHANGED
@@ -29,13 +29,14 @@ from mindspore.ops.vm_impl_registry import get_vm_impl_fn, vm_impl_registry
 from mindspore.ops.op_info_register import op_info_register, custom_info_register, AkgGpuRegOp, AkgAscendRegOp, \
     AiCPURegOp, TBERegOp, CpuRegOp, CustomRegOp, DataType
 from mindspore.ops.primitive import constexpr
-from mindspore.ops import composite, operations, functional, function, auto_generate, extend
+from mindspore.ops import composite, operations, functional, function
 from mindspore.ops import signature
+from mindspore.ops.auto_generate import cpp_create_prim_instance_helper, gen_arg_dtype_cast, gen_arg_handler, \
+    gen_extend_func, gen_ops_def, gen_ops_prim, pyboost_inner_prim
 from mindspore.ops.composite import *
 from mindspore.ops.operations import *
 from mindspore.ops.function import *
 from mindspore.ops.functional import *
-from mindspore.ops.silent_check import _silent_check
 
 __primitive__ = [
     "prim_attr_register", "prim_arg_register", "Primitive", "PrimitiveWithInfer", "PrimitiveWithCheck", "signature"
@@ -44,11 +45,12 @@ __primitive__ = [
 __all__ = ["get_vm_impl_fn", "vm_impl_registry",
            "op_info_register", "custom_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp",
            "CpuRegOp", "CustomRegOp", "DataType",
-           "constexpr", "reshard"]
+           "constexpr", "reshard",
+           "cpp_create_prim_instance_helper", "gen_arg_dtype_cast", "gen_arg_handler", "gen_extend_func", "gen_ops_def",
+           "gen_ops_prim", "pyboost_inner_prim"]
 __all__.extend(__primitive__)
 __all__.extend(composite.__all__)
 __all__.extend(operations.__all__)
 __all__.extend(functional.__all__)
 __all__.extend(function.__all__)
 __all__.extend(auto_generate.__all__)
-_silent_check()
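Migration note: `mindspore.ops.extend` and the Python-level silent-check hook disappear from the 2.4 import surface, while the generated `auto_generate` submodules are now re-exported. A hedged sketch of what 2.3 code relying on the removed module might migrate to, assuming the equivalents live under `mindspore.mint` (whose greatly expanded `__init__.py` appears in the file list above; this mapping is an assumption, not a documented rule):

    # Hedged sketch -- 2.3 code that did `from mindspore.ops import extend`
    # no longer imports in 2.4. The expanded mindspore.mint namespace appears
    # to be the replacement surface; verify against the 2.4 release notes.
    from mindspore import mint

    y = mint.ones((2, 3))  # mint mirrors common tensor-creation/NN APIs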
mindspore/ops/_grad_experimental/grad_comm_ops.py CHANGED
@@ -34,6 +34,7 @@ from mindspore.ops.operations.comm_ops import (AllGather, _MiniStepAllGather, _H
                                                _MicroStepAllGather, Reduce, CollectiveGather, CollectiveScatter)
 from mindspore.ops._grad_experimental.grad_base import bprop_getters
 from mindspore.ops.operations import _grad_ops as G
+import mindspore as ms
 
 
 @bprop_getters.register(AllReduce)
@@ -95,6 +96,12 @@ def get_bprop_send(self):
     dtype = self.get_attr_dict()["dtype"]
     tag = self.get_attr_dict()["sr_tag"]
     send_grad = Receive(tag, self.rank, shape, dtype, self.group_back)
+    if "dst_global_rank" in self.get_attr_dict():
+        dst_global_rank = self.get_attr_dict().get("dst_global_rank")
+        send_grad.add_prim_attr("src_global_rank", dst_global_rank)
+    if "RING_ATTENTION_INDEX" in self.get_attr_dict():
+        ringattention = self.get_attr_dict().get("RING_ATTENTION_INDEX")
+        send_grad.add_prim_attr("RING_ATTENTION_INDEX", ringattention)
     virtual_input = Tensor(0.0, dtype)
 
     def bprop(x, out, dout):
@@ -108,8 +115,16 @@ def get_bprop_send(self):
 def get_bprop_receive(self):
     """Generate bprop for Receive."""
     tag = self.get_attr_dict()["sr_tag"]
+    flash_tag = self.get_attr_dict().get("flash_tag")
     receive_grad = Send(tag, self.rank, self.group_back)
-    receive_grad.add_prim_attr("shape", self.shape)
+    shape = self.get_attr_dict()["shape"]
+    receive_grad.add_prim_attr("shape", shape)
+    if "src_global_rank" in self.get_attr_dict():
+        src_global_rank = self.get_attr_dict().get("src_global_rank")
+        receive_grad.add_prim_attr("dst_global_rank", src_global_rank)
+    if "RING_ATTENTION_INDEX" in self.get_attr_dict():
+        ringattention = self.get_attr_dict().get("RING_ATTENTION_INDEX")
+        receive_grad.add_prim_attr("RING_ATTENTION_INDEX", ringattention)
     depend = P.Depend()
     cast = P.Cast()
     out_tensor = Tensor(0.0, mstype.float16)
@@ -117,7 +132,7 @@ def get_bprop_receive(self):
 
     def bprop(x, out, dout):
         send_out = receive_grad(dout)
-        if is_opt_shard:
+        if is_opt_shard or (flash_tag == "True"):
             dx = depend(F.zeros_like(x), send_out)
         else:
             dx = depend(cast(out_tensor, F.dtype(x)), send_out)
@@ -186,6 +201,9 @@ def get_bprop_mirror_micro_step_operator(self):
     group = self.group
     dev_num = self.dev_num
     mean_flag = self.mean_flag
+    param_name = " "
+    if 'mirror_user_id' in self.get_attr_dict():
+        param_name = self.get_attr_dict()['mirror_user_id']
     scale = 1 / dev_num
 
     all_reduce = AllReduce(group=group)
@@ -196,7 +214,6 @@ def get_bprop_mirror_micro_step_operator(self):
     if hasattr(self, 'parameter'):
         parameter = self.parameter
         all_reduce.add_prim_attr("parameter", parameter)
-
     if self.instance_name:
         instance_name = "grad_mirror" + self.instance_name
         all_reduce.set_prim_instance_name(instance_name)
@@ -207,8 +224,14 @@ def get_bprop_mirror_micro_step_operator(self):
     assign.add_prim_attr("parameter_micro", 0)
     out_tensor = Tensor(1.0, mstype.float16)
     opt_shard = _get_enable_parallel_optimizer()
+    ln_print = P.Print()
+    reduce_sum = P.ReduceSum(keep_dims=False)
+    square = P.Square()
+    dump_local_norm = ms.get_auto_parallel_context("dump_local_norm")
 
     def bprop(x, z, out, dout):
+        if dump_local_norm:
+            z = F.depend(z, ln_print("dump local norm: ", param_name, reduce_sum(square((z)))))
         real_grad = z
         assign_out = dout
         if issubclass_(F.typeof(dout), mstype.tensor_type):
@@ -309,6 +332,9 @@ def get_bprop_micro_step_all_gather(self):
     """Generate bprop for _MicroStepAllGather"""
     fusion = self.get_attr_dict()["fusion"]
     mean_flag = self.get_attr_dict()["mean_flag"]
+    param_name = " "
+    if 'mirror_user_id' in self.get_attr_dict():
+        param_name = self.get_attr_dict()['mirror_user_id']
     do_mirror = False
     if self.group != "":
         do_mirror = self.get_attr_dict()["do_mirror"]
@@ -324,6 +350,10 @@ def get_bprop_micro_step_all_gather(self):
     dtype = P.DType()
     out_tensor = Tensor(1.0, mstype.float16)
     with_mirror_operator = self.get_attr_dict()["with_mirror_operator"]
+    ln_print = P.Print()
+    reduce_sum = P.ReduceSum(keep_dims=False)
+    square = P.Square()
+    dump_local_norm = ms.get_auto_parallel_context("dump_local_norm")
 
     def bprop(x, z, out, dout):
         if with_mirror_operator:
@@ -334,6 +364,8 @@ def get_bprop_micro_step_all_gather(self):
             real_grad = F.tensor_mul(real_grad, scale)
             return (real_grad, cast(out_tensor, dtype(z)))
         z = F.depend(z, dout)
+        if dump_local_norm:
+            z = F.depend(z, ln_print("dump local norm: ", param_name, reduce_sum(square((z)))))
         if not do_mirror:
             return (z, cast(out_tensor, dtype(z)))
         real_grad = reduce_scatter(z)
@@ -529,16 +561,25 @@ def get_bprop_mirror_operator(self):
     group = self.get_attr_dict()['group']
     dev_num = self.get_attr_dict()['dev_num']
     mean_flag = self.get_attr_dict()['mean_flag']
+    param_name = " "
+    if 'mirror_user_id' in self.get_attr_dict():
+        param_name = self.get_attr_dict()['mirror_user_id']
+
     dev_num_r = 1.0
+    dump_local_norm = ms.get_auto_parallel_context("dump_local_norm")
     if dev_num > 1:
         dev_num_r = 1.0 / dev_num
         all_reduce = AllReduce(group=group)
         all_gather = AllGather(group=group)
         mul = P.Mul()
         cast = P.Cast()
+        ln_print = P.Print()
+        reduce_sum = P.ReduceSum(keep_dims=False)
+        square = P.Square()
 
         fusion = self.get_attr_dict()["fusion"]
         all_reduce.add_prim_attr("fusion", fusion)
+        parameter = " "
         if hasattr(self, 'parameter'):
            parameter = self.parameter
            all_reduce.add_prim_attr("parameter", parameter)
@@ -548,6 +589,9 @@ def get_bprop_mirror_operator(self):
            all_reduce.set_prim_instance_name(instance_name)
 
     def bprop(x, out, dout):
+        if dump_local_norm:
+            dout = F.depend(dout, ln_print("dump local norm: ", param_name, reduce_sum(square((dout)))))
+
         if dev_num == 1:
             return (dout,)
         if mean_flag:
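The recurring additions above gate a "dump local norm" print on a new auto-parallel context flag that the bprops read via `ms.get_auto_parallel_context("dump_local_norm")`. A hedged sketch of turning it on, assuming `dump_local_norm` is also accepted as a setter key in 2.4 (the diff only shows it being read):

    # Hedged sketch -- assumes 2.4 accepts dump_local_norm as an
    # auto-parallel context key on the setter side as well.
    import mindspore as ms

    ms.set_auto_parallel_context(dump_local_norm=True)
    # During backprop, each mirror/all-gather bprop would then print
    # "dump local norm: <param> <sum(grad**2)>" for its gradient.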
mindspore/ops/_grad_experimental/grad_math_ops.py CHANGED
@@ -18,12 +18,9 @@
 import numpy as np
 import mindspore.numpy as mnp
 from mindspore.common import dtype as mstype
-import mindspore.ops as ops
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
 from mindspore import Tensor
-from mindspore.ops.operations.math_ops import SilentCheck
-from mindspore.ops.operations._inner_ops import _MirrorSilentCheck
 from mindspore.ops.operations.math_ops import CumulativeLogsumexp
 from mindspore.ops.operations.math_ops import MatrixSolve
 from mindspore.ops.operations.math_ops import MatrixSolveLs
@@ -803,22 +800,3 @@ def get_bprop_tensor_add(self):
         return binop_grad_common(x, y, dout, dout)
 
     return bprop
-
-
-@bprop_getters.register(_MirrorSilentCheck)
-def get_bprop_mirror_silent_check(self):
-    """Grad definition for '_MirrorSilentCheck' op"""
-    silent_check = SilentCheck(self.min_steps, self.thresh_l1, self.coeff_l1, self.thresh_l2, self.coeff_l2)
-    out_tensor = Tensor([0.0], mstype.float32)
-
-    def bporp(x, pre_val, min_val, max_val, n_step, loss_scale, out, dout):
-        if dout.dtype == mstype.float16:
-            return (dout, out_tensor, out_tensor, out_tensor, out_tensor, out_tensor)
-        if loss_scale is not None:
-            gnorm = ops.norm(dout / loss_scale)
-        else:
-            gnorm = ops.norm(dout)
-        dx, _, _, _, _ = silent_check(gnorm, dout, pre_val, min_val, max_val, n_step)
-        return (dx, out_tensor, out_tensor, out_tensor, out_tensor, out_tensor)
-
-    return bporp
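Together with the deletion of mindspore/ops/silent_check.py in the file list, this removes the Python-level silent-check bprop; 2.4 instead registers a SilentCheckV2 primitive (its defaults, including npu_asd_detect, appear in the helper diff further down). A hedged sketch of how detection is typically toggled, assuming the NPU_ASD_ENABLE environment variable from earlier releases still drives it:

    # Hedged sketch -- NPU_ASD_ENABLE is the documented switch for Ascend
    # silent-data-corruption detection in earlier releases; assumed here to
    # feed SilentCheckV2's npu_asd_detect default (1 = enabled).
    import os

    os.environ["NPU_ASD_ENABLE"] = "1"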
mindspore/ops/_vmap/vmap_array_ops.py CHANGED
@@ -2113,6 +2113,7 @@ def get_split_vmap_rule(prim, axis_size):
 
     return vmap_rule
 
+
 @vmap_rules_getters.register(P.SearchSorted)
 def get_searchsorted_vmap_rule(prim, axis_size):
     """VmapRule for `SearchSorted`."""
@@ -2131,10 +2132,7 @@ def get_searchsorted_vmap_rule(prim, axis_size):
         if sorter is not None and sorter_dim is not None:
             sorter = _bdim_at_front(sorter, sorter_dim, axis_size)
 
-        dtype, _ = dtype_bdim
-        right, _ = right_bdim
-
-        outputs = prim(sequence, values, sorter, dtype, right)
+        outputs = prim(sequence, values, sorter, dtype_bdim[0], right_bdim[0])
 
         return outputs, 0
 
mindspore/ops/_vmap/vmap_math_ops.py CHANGED
@@ -916,6 +916,23 @@ def get_isclose_vmap_rule(prim, axis_size):
 
     return vmap_rule
 
+
+@vmap_rules_getters.register(P.Round)
+def get_round_vmap_rule(prim, axis_size):
+    """VmapRule for round."""
+    if isinstance(prim, str):
+        prim = Primitive(prim)
+
+    def vmap_rule(x_bdim, decimal_bdim):
+        var, x_dim = x_bdim
+        decimal_var, decimal_dim = decimal_bdim
+        if decimal_dim is not None:
+            _raise_value_error("For vmap, the batch axis of decimal must be none.")
+        out = prim(var, decimal_var)
+        return out, x_dim
+
+    return vmap_rule
+
 get_assign_vmap_rule = vmap_rules_getters.register(P.AssignAdd)(get_assign_vmap_rule)
 get_assign_vmap_rule = vmap_rules_getters.register(P.AssignSub)(get_assign_vmap_rule)
 
@@ -949,7 +966,6 @@ get_unop_vmap_rule = vmap_rules_getters.register(P.Reciprocal)(get_unop_vmap_rul
 get_unop_vmap_rule = vmap_rules_getters.register(P.Inv)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.Invert)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.Rint)(get_unop_vmap_rule)
-get_unop_vmap_rule = vmap_rules_getters.register(P.Round)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.Rsqrt)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register("Sigmoid")(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.Sqrt)(get_unop_vmap_rule)
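The dedicated rule exists because Round now takes a decimals input (its default is registered as `"Round": {"decimals": 0}` in the helper diff below), so it no longer fits the generic unary-op rule. A hedged usage sketch, assuming the functional wrapper exposes the new argument:

    # Hedged sketch -- assumes ops.round exposes decimals in 2.4, matching
    # the "Round": {"decimals": 0} default registered below.
    import numpy as np
    from mindspore import Tensor, ops

    x = Tensor(np.random.randn(4, 3).astype(np.float32))
    batched = ops.vmap(lambda t: ops.round(t, decimals=2), in_axes=0)
    print(batched(x).shape)  # (4, 3); batching decimals itself would raise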
mindspore/ops/_vmap/vmap_nn_ops.py CHANGED
@@ -517,7 +517,6 @@ def get_in_top_k_vmap_rule(prim, axis_size):
 
 @vmap_rules_getters.register(G.FastGeLUGrad)
 @vmap_rules_getters.register(G.HSwishGrad)
-@vmap_rules_getters.register(G.SoftShrinkGrad)
 def get_common_activation_grad_vmap_rule(prim, axis_size):
     """VmapRule for common activation grad operation."""
     prim_name = prim.name
@@ -547,6 +546,49 @@ def get_common_activation_grad_vmap_rule(prim, axis_size):
     return vmap_rule
 
 
+@vmap_rules_getters.register("SoftShrink")
+def get_softshrink_vmap_rule(prim, axis_size):
+    """VmapRule for `SoftShrink`."""
+    def vmap_rule(x_bdim, lambd_bdim):
+        var, dim = x_bdim
+        lambd, _ = lambd_bdim
+        out = prim(var, lambd)
+        return out, dim
+
+    return vmap_rule
+
+
+@vmap_rules_getters.register("SoftShrinkGrad")
+def get_softshrink_grad_vmap_rule(prim, axis_size):
+    """VmapRule for `SoftShrinkGrad`."""
+    prim_name = prim.name
+
+    def vmap_rule(dy_bdim, x_bdim, lambd_bdim):
+        x, x_dim = x_bdim
+        lambd, _ = lambd_bdim
+        dy, dy_dim = dy_bdim
+        x_shape = F.shape(x)
+        dy_shape = F.shape(dy)
+        if x_dim == dy_dim and x_shape == dy_shape:
+            out = prim(dy, x, lambd)
+            return out, x_dim
+
+        if F.rank(x):
+            x = _bdim_at_front(x, x_dim, 1)
+        if F.rank(dy):
+            dy = _bdim_at_front(dy, dy_dim, 1)
+        x_shape = F.shape(x)
+        dy_shape = F.shape(dy)
+        if x_shape != dy_shape:
+            raise RuntimeError("For {} vmap, input x shape is supposed to be the same as input dy shape "
+                               "after batch transforming, but got x_shape {}, dy_shape {}"
+                               .format(prim_name, x_shape, dy_shape))
+        out = prim(dy, x, lambd)
+        return out, 0
+
+    return vmap_rule
+
+
 @vmap_rules_getters.register("HShrink")
 def get_hshrink_vmap_rule(prim, axis_size):
     """VmapRule for `HShrink`."""
@@ -2196,7 +2238,6 @@ get_unop_vmap_rule = vmap_rules_getters.register(P.SeLU)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.HSigmoid)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.Softplus)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.Softsign)(get_unop_vmap_rule)
-get_unop_vmap_rule = vmap_rules_getters.register(P.SoftShrink)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.GeLU)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.FastGeLU)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.HSwish)(get_unop_vmap_rule)
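As with Round, SoftShrink's lambd has become a primitive input (see the `"SoftShrink": {"lambd": 0.5}` default below), so it moves off the generic unary rule onto its own. A hedged usage sketch:

    # Hedged sketch -- assumes the functional ops.softshrink(x, lambd=...)
    # wrapper matches the new primitive signature.
    import numpy as np
    from mindspore import Tensor, ops

    x = Tensor(np.linspace(-1.0, 1.0, 12).astype(np.float32).reshape(3, 4))
    out = ops.vmap(lambda t: ops.softshrink(t, lambd=0.5))(x)  # lambd stays unbatched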
mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py CHANGED
@@ -21,6 +21,7 @@ op_args_default_value = {
     "AdamW": {"amsgrad": False, "maximize": False},
     "AddExt": {"alpha": 1},
     "AddLayerNormV2": {"epsilon": 1e-5, "additionalOut": False},
+    "ApplyAdamW": {"max_grad_norm": None, "amsgrad": False, "maximize": False},
     "ApplyCamePart2": {"sum_r": None, "global_shape": None},
     "ApplyCamePart3": {"global_shape": None, "use_first_moment": False},
     "ApplyCamePart4": {"global_shape": None},
@@ -29,6 +30,7 @@ op_args_default_value = {
     "ArgMaxExt": {"dim": None, "keepdim": False},
     "Argmax": {"axis": -1, "output_type": mstype.int32},
     "ArgMaxWithValue": {"axis": 0, "keep_dims": False},
+    "ArgMinExt": {"dim": None, "keepdim": False},
     "Argmin": {"axis": -1, "output_type": mstype.int32},
     "ArgMinWithValue": {"axis": 0, "keep_dims": False},
     "AvgPool2DGrad": {"padding": 0, "ceil_mode": False, "count_include_pad": True, "divisor_override": None},
@@ -36,8 +38,8 @@ op_args_default_value = {
     "AvgPoolGrad": {"kernel_size": 1, "strides": 1, "pad_mode": 'VALID', "data_format": 'NCHW'},
     "AvgPool": {"kernel_size": 1, "strides": 1, "pad_mode": 'VALID', "data_format": 'NCHW'},
     "BatchMatMul": {"transpose_a": False, "transpose_b": False},
-    "BatchNormExt": {"training": False, "momentum": 0.1, "epsilon": 1e-5},
-    "BatchNormGradExt": {"training": False, "eps": 1e-5},
+    "BatchNormExt": {"running_mean": None, "runnning_var": None, "training": False, "momentum": 0.1, "epsilon": 1e-5},
+    "BatchNormGradExt": {"running_mean": None, "running_var": None, "saved_mean": None, "saved_rstd": None, "training": False, "eps": 1e-5},
     "BatchNormGradGrad": {"is_training": False, "epsilon": 1e-5, "data_format": 'NCHW'},
     "BatchNormGrad": {"is_training": False, "epsilon": 1e-5, "data_format": 'NCHW'},
     "BatchNormGradWithActivation": {"is_training": False, "epsilon": 1e-5, "data_format": 'NCHW'},
@@ -63,10 +65,12 @@ op_args_default_value = {
     "ConvolutionGrad": {"bias": None, "stride": 1, "padding": 0, "dilation": 1, "transposed": False, "output_padding": 0, "groups": 1, "output_mask": ()},
     "Convolution": {"bias": None, "stride": 1, "padding": 0, "dilation": 1, "transposed": False, "output_padding": 0, "groups": 1},
     "Correlate": {"mode": 'valid'},
+    "Cross": {"dim": -65530},
     "CumProd": {"exclusive": False, "reverse": False},
     "CumSum": {"exclusive": False, "reverse": False},
     "CumsumExt": {"dtype": None},
-    "DCT": {"axis": -1, "norm": 'BACKWARD', "forward": True, "grad": False},
+    "DCT": {"type": 2, "n": None, "axis": -1, "norm": None},
+    "DCTN": {"type": 2, "s": None, "axes": None, "norm": None},
     "Dense": {"bias": None},
     "Diagonal": {"offset": 0, "dim1": 0, "dim2": 1},
     "DivMod": {"rounding_mode": None},
@@ -75,13 +79,25 @@ op_args_default_value = {
     "EluExt": {"alpha": 1.0},
     "EluGradExt": {"alpha": 1.0},
     "Elu": {"alpha": 1.0},
+    "EmbeddingApplyAdaGrad": {"mask_zero": (0,), "padding_key": (0,), "padding_key_mask": (1,), "completion_key": (0,), "completion_key_mask": (1,), "_embedding_dim": 1, "_max_key_num": 1},
+    "EmbeddingApplyAdam": {"mask_zero": (0,), "padding_key": (0,), "padding_key_mask": (1,), "completion_key": (0,), "completion_key_mask": (1,), "_embedding_dim": 1, "_max_key_num": 1},
+    "EmbeddingApplyAdamW": {"ams_grad": (0,), "mask_zero": (0,), "padding_key": (0,), "padding_key_mask": (1,), "completion_key": (0,), "completion_key_mask": (1,), "_embedding_dim": 1, "_max_key_num": 1},
+    "EmbeddingApplyFtrl": {"mask_zero": (0,), "padding_key": (0,), "padding_key_mask": (1,), "completion_key": (0,), "completion_key_mask": (1,), "_embedding_dim": 1, "_max_key_num": 1},
+    "EmbeddingApplyRmsprop": {"mask_zero": (0,), "padding_key": (0,), "padding_key_mask": (1,), "completion_key": (0,), "completion_key_mask": (1,), "_embedding_dim": 1, "_max_key_num": 1},
+    "EmbeddingApplySgd": {"mask_zero": (0,), "padding_key": (0,), "padding_key_mask": (1,), "completion_key": (0,), "completion_key_mask": (1,), "_embedding_dim": 1, "_max_key_num": 1},
     "EmbeddingDenseBackward": {"padding_idx": None, "scale_grad_by_freq": False},
+    "EmbeddingFeatureMappingFileSize": {"only_offset_flag": True},
+    "EmbeddingFeatureMappingFind": {"num": 1},
+    "EmbeddingFeatureMappingImport": {"only_offset_flag": True, "num": 1},
     "Embedding": {"padding_idx": None, "max_norm": None, "norm_type": 2.0, "scale_grad_by_freq": False},
+    "EmbeddingTableEvict": {"steps_to_live": 0},
     "ExtractImagePatches": {"padding": 'VALID'},
     "FFNExt": {"expertTokens": None, "bias1": None, "bias2": None, "scale": None, "offset": None, "deqScale1": None, "deqScale2": None, "antiquant_scale1": None, "antiquant_scale2": None, "antiquant_offset1": None, "antiquant_offset2": None, "activation": 'fastgelu', "inner_precise": 0},
     "FFT2": {"s": None, "dim": (-2, -1), "norm": None},
     "FFT": {"n": None, "dim": -1, "norm": None},
+    "FFTOrtho": {"axes": None, "forward": True},
     "FFTWithSize": {"norm": 'backward', "onesided": True, "signal_sizes": ()},
+    "FFTFreq": {"d": 1.0, "dtype": None},
     "FFTN": {"s": None, "dim": None, "norm": None},
     "FFTShift": {"dim": None},
     "FillScalar": {"dtype": None},
@@ -90,23 +106,42 @@ op_args_default_value = {
     "FlashAttentionScore": {"real_shift": None, "drop_mask": None, "padding_mask": None, "attn_mask": None, "prefix": None, "actual_seq_qlen": None, "actual_seq_kvlen": None, "keep_prob": 1.0, "scale_value": 1.0, "pre_tokens": 2147483647, "next_tokens": 2147483647, "inner_precise": 0, "input_layout": 'BSH', "sparse_mode": 0},
     "FlattenExt": {"start_dim": 0, "end_dim": -1},
     "Gather": {"batch_dims": 0},
+    "GenerateEodMaskV2": {"start": 0, "steps": 1, "error_mode": 'cycle', "flip_mode": 'bitflip', "multiply_factor": 0.0, "bit_pos": 0, "flip_probability": 0.0},
     "GridSampler2DGrad": {"interpolation_mode": 'bilinear', "padding_mode": 'zeros', "align_corners": False},
     "GridSampler2D": {"interpolation_mode": 'bilinear', "padding_mode": 'zeros', "align_corners": False},
     "GridSampler3DGrad": {"interpolation_mode": 'bilinear', "padding_mode": 'zeros', "align_corners": False},
     "GridSampler3D": {"interpolation_mode": 'bilinear', "padding_mode": 'zeros', "align_corners": False},
     "GroupNormGrad": {"dx_is_require": True, "dgamma_is_require": True, "dbeta_is_require": True},
     "GroupNorm": {"weight": None, "bias": None, "eps": 1e-5},
+    "HFFT2": {"s": None, "dim": (-2, -1), "norm": None},
+    "HFFT": {"n": None, "dim": -1, "norm": None},
+    "HFFTN": {"s": None, "dim": None, "norm": None},
+    "HistcExt": {"bins": 100, "min": 0, "max": 0},
     "HShrinkGrad": {"lambd": 0.5},
     "HShrink": {"lambd": 0.5},
+    "IDCT": {"type": 2, "n": None, "axis": -1, "norm": None},
+    "IDCTN": {"type": 2, "s": None, "axes": None, "norm": None},
     "IFFT2": {"s": None, "dim": (-2, -1), "norm": None},
     "IFFT": {"n": None, "dim": -1, "norm": None},
     "IFFTN": {"s": None, "dim": None, "norm": None},
     "IFFTShift": {"dim": None},
+    "IHFFT2": {"s": None, "dim": (-2, -1), "norm": None},
+    "IHFFT": {"n": None, "dim": -1, "norm": None},
+    "IHFFTN": {"s": None, "dim": None, "norm": None},
     "Im2ColExt": {"dilation": 1, "padding": 0, "stride": 1},
+    "IncreFlashAttention": {"attn_mask": None, "actual_seq_lengths": None, "pse_shift": None, "dequant_scale1": None, "quant_scale1": None, "dequant_scale2": None, "quant_scale2": None, "quant_offset2": None, "antiquant_scale": None, "antiquant_offset": None, "block_table": None, "kv_padding_size": None, "num_heads": 1, "input_layout": 'BSH', "scale_value": 1.0, "num_key_value_heads": 0, "block_size": 0, "inner_precise": 1},
     "IndexAddExt": {"alpha": 1},
-    "IRFFTGrad": {"n": None, "dim": -1, "norm": None},
+    "InplaceAddExt": {"alpha": 1},
+    "InplaceAddmm": {"beta": 1, "alpha": 1},
+    "InplaceAddsExt": {"alpha": 1},
+    "InsertGemV2InBackward": {"start": 0, "steps": 1, "error_mode": 'cycle', "flip_mode": 'bitflip', "multiply_factor": 0.0, "bit_pos": 0, "flip_probability": 0.0},
+    "IRFFT2": {"s": None, "dim": (-2, -1), "norm": None},
+    "IRFFTDouble": {"dim": -1},
     "IRFFT": {"n": None, "dim": -1, "norm": None},
+    "IRFFTN": {"s": None, "dim": None, "norm": None},
     "IsClose": {"rtol": 1e-05, "atol": 1e-08, "equal_nan": True},
+    "L1LossBackwardExt": {"reduction": 'mean'},
+    "L1LossExt": {"reduction": 'mean'},
     "LayerNormExt": {"weight": None, "bias": None, "eps": 1e-5},
     "LayerNormGradGrad": {"begin_norm_axis": 1, "begin_params_axis": 1},
     "LayerNormGrad": {"begin_norm_axis": 1, "begin_params_axis": 1},
@@ -116,10 +151,13 @@ op_args_default_value = {
     "LeakyReLUExt": {"negative_slope": 0.01},
     "LeakyReLUGradExt": {"negative_slope": 0.01, "is_result": False},
     "LinSpaceExt": {"dtype": None},
+    "LogSoftmaxExt": {"dim": None, "dtype": None},
     "LogSoftmaxGrad": {"axis": -1},
     "LogSoftmax": {"axis": -1},
     "LogitGrad": {"eps": -1.0},
     "Logit": {"eps": -1.0},
+    "LpNormV2": {"p": 2.0, "dim": None, "keepdim": False, "epsilon": 1e-12},
+    "LstsqV2": {"driver": None},
     "MatMul": {"transpose_a": False, "transpose_b": False},
     "MaxPoolGradWithIndices": {"strides": None, "pads": 0, "dilation": (1, 1), "ceil_mode": False, "argmax_type": mstype.int64},
     "MaxPoolGradWithMask": {"strides": None, "pads": 0, "dilation": (1, 1), "ceil_mode": False, "argmax_type": mstype.int64},
@@ -128,15 +166,20 @@ op_args_default_value = {
     "MaximumGradGrad": {"grad_x": True, "grad_y": True},
     "MaximumGrad": {"grad_x": True, "grad_y": True},
     "MeanExt": {"axis": None, "keep_dims": False, "dtype": None},
+    "MedianDim": {"dim": -1, "keepdim": False},
     "MinimumGrad": {"grad_x": True, "grad_y": True},
+    "MSELossExt": {"reduction": 'mean'},
+    "MSELossGradExt": {"reduction": 'mean'},
     "NanToNum": {"nan": None, "posinf": None, "neginf": None},
     "NLLLossGrad": {"reduction": 'mean', "ignore_index": -100},
     "NLLLoss": {"reduction": 'mean', "ignore_index": -100},
-    "Norm": {"ord": None, "dim": None, "keepdim": False, "dtype": None},
+    "Norm": {"p": 2.0, "dim": None, "keepdim": False, "dtype": None},
     "OneHotExt": {"axis": -1},
     "OneHot": {"axis": -1},
     "OnesLikeExt": {"dtype": None},
     "Ones": {"dtype": None},
+    "PagedAttentionMask": {"antiquant_scale": None, "antiquant_offset": None, "alibi_mask": None, "kv_cache_quant_mode": 'DEFAULT'},
+    "PagedAttention": {"antiquant_scale": None, "antiquant_offset": None, "attn_mask": None, "q_seq_lens": None, "kv_cache_quant_mode": 'DEFAULT'},
     "ProdExt": {"axis": None, "keep_dims": False, "dtype": None},
     "PromptKVCache": {"align_mode": 'LEFT'},
     "Qr": {"full_matrices": False},
@@ -165,16 +208,27 @@ op_args_default_value = {
     "ResizeNearestNeighbor": {"align_corners": False, "half_pixel_centers": False},
     "ResizeNearestNeighborV2Grad": {"align_corners": False, "half_pixel_centers": False},
     "ResizeNearestNeighborV2": {"align_corners": False, "half_pixel_centers": False},
-    "RFFTGrad": {"n": None, "dim": -1, "norm": None},
+    "RFFT2": {"s": None, "dim": (-2, -1), "norm": None},
     "RFFT": {"n": None, "dim": -1, "norm": None},
+    "RFFTFreq": {"d": 1.0, "dtype": None},
+    "RFFTN": {"s": None, "dim": None, "norm": None},
     "RmsNorm": {"epsilon": 1e-6},
+    "Roll": {"axis": None},
+    "RotaryPositionEmbeddingGrad": {"dx": None, "mode": 0},
+    "RotaryPositionEmbedding": {"mode": 0},
+    "Round": {"decimals": 0},
     "ScalarToTensor": {"dtype": None},
+    "Scatter": {"reduce": 'none'},
+    "ScatterValue": {"reduce": 'none'},
     "SearchSorted": {"sorter": None, "dtype": mstype.int64, "right": False},
     "SequenceConcat": {"axis": 0},
+    "SilentCheckV2": {"c_min_steps": 7, "c_thresh_l1": 1000000.0, "c_coeff_l1": 100000.0, "c_thresh_l2": 10000.0, "c_coeff_l2": 5000.0, "npu_asd_detect": 1},
     "SoftmaxBackward": {"dim": -1},
     "Softmax": {"axis": -1},
     "SoftplusExt": {"beta": 1, "threshold": 20},
     "SoftplusGradExt": {"beta": 1, "threshold": 20},
+    "SoftShrinkGrad": {"lambd": 0.5},
+    "SoftShrink": {"lambd": 0.5},
     "SolveTriangular": {"trans": 0, "lower": False, "unit_diagonal": False},
     "SortExt": {"dim": -1, "descending": False, "stable": False},
     "Split": {"axis": 0, "output_num": 1},
@@ -184,11 +238,20 @@ op_args_default_value = {
     "StridedSlice": {"begin_mask": 0, "end_mask": 0, "ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0},
     "SubExt": {"alpha": 1},
     "SumExt": {"dim": None, "keepdim": False, "dtype": None},
+    "SwigluGrad": {"dim": -1},
+    "Swiglu": {"dim": -1},
+    "TensorScatterElements": {"axis": 0, "reduce": 'none'},
     "TopkExt": {"dim": -1, "largest": True, "sorted": True},
+    "TopKRouter": {"drop_type": 0},
+    "TraceV2Grad": {"offset": 0, "axis1": 1, "axis2": 0},
+    "TraceV2": {"offset": 0, "axis1": 1, "axis2": 0, "dtype": None},
+    "TrilExt": {"diagonal": 0},
     "Triu": {"diagonal": 0},
     "TupleToTensor": {"dtype": None},
     "Unique2": {"sorted": True, "return_inverse": False, "return_counts": False},
     "UnstackExt": {"axis": 0},
+    "UpsampleBicubic2DGrad": {"output_size": None, "scales": None, "align_corners": False},
+    "UpsampleBicubic2D": {"output_size": None, "scales": None, "align_corners": False},
     "UpsampleBilinear2DGrad": {"output_size": None, "scales": None, "align_corners": False},
     "UpsampleBilinear2D": {"output_size": None, "scales": None, "align_corners": False},
     "UpsampleLinear1DGrad": {"output_size": None, "scales": None, "align_corners": False},
@@ -203,12 +266,13 @@ op_args_default_value = {
     "UpsampleTrilinear3D": {"output_size": None, "scales": None, "align_corners": False},
     "ZerosLikeExt": {"dtype": None},
     "Zeros": {"dtype": None},
+    "AddRmsNormQuantV2": {"epsilon": 1e-5},
     "DynamicQuantExt": {"smooth_scales": None},
     "FusedInferAttentionScore": {"pse_shift": None, "attn_mask": None, "actual_seq_lengths": None, "actual_seq_lengths_kv": None, "dequant_scale1": None, "quant_scale1": None, "dequant_scale2": None, "quant_scale2": None, "quant_offset2": None, "antiquant_scale": None, "antiquant_offset": None, "block_table": None, "query_padding_size": None, "kv_padding_size": None, "scale_value": 1.0, "pre_tokens": 2147483647, "next_tokens": 2147483647, "input_layout": 'BSH', "num_key_value_heads": 0, "sparse_mode": 0, "inner_precise": 1, "block_size": 0, "antiquant_mode": 0, "softmax_lse_flag": False},
     "GroupedMatmul": {"bias": None, "scale": None, "offset": None, "antiquant_scale": None, "antiquant_offset": None, "group_list": None, "split_item": 0, "group_type": -1},
     "KVCacheScatterUpdate": {"reduce": 'none'},
     "MoeFinalizeRouting": {"x2": None, "bias": None, "scales": None, "expanded_row_idx": None, "expanded_expert_idx": None},
-    "QuantBatchMatmul": {"offset": None, "bias": None, "transpose_x1": False, "transpose_x2": False, "dtype": mstype.float16},
+    "QuantBatchMatmul": {"offset": None, "bias": None, "pertokenScaleOptional": None, "transpose_x1": False, "transpose_x2": False, "dtype": mstype.float16},
     "QuantV2": {"sqrt_mode": False, "rounding_mode": 'ROUND', "dst_type": mstype.int8},
     "WeightQuantBatchMatmul": {"antiquant_offset": None, "quant_scale": None, "quant_offset": None, "bias": None, "transpose_x": False, "transpose_weight": False, "antiquant_group_size": 0},
 }
@@ -216,16 +280,30 @@ op_args_default_value = {
 op_labels = {
     "AdamWeightDecay": {"side_effect_mem": True},
     "AdamW": {"side_effect_mem": True},
+    "ApplyAdamW": {"side_effect_mem": True},
     "AssignAdd": {"side_effect_mem": True},
     "Assign": {"side_effect_mem": True},
+    "CopyExt": {"side_effect_mem": True},
     "DecoderKVCache": {"side_effect_mem": True},
     "DropoutExt": {"side_effect_hidden": True},
     "DropoutGenMaskExt": {"side_effect_hidden": True},
     "Dropout": {"side_effect_hidden": True},
+    "EmbeddingApplyAdaGrad": {"_process_node_engine_id": 'PS'},
+    "EmbeddingApplyAdam": {"_process_node_engine_id": 'PS'},
+    "EmbeddingApplyAdamW": {"_process_node_engine_id": 'PS'},
+    "EmbeddingApplyFtrl": {"_process_node_engine_id": 'PS'},
+    "EmbeddingApplyRmsprop": {"_process_node_engine_id": 'PS'},
+    "EmbeddingApplySgd": {"_process_node_engine_id": 'PS'},
     "Embedding": {"side_effect_mem": True},
+    "EmbeddingTableEvict": {"_process_node_engine_id": 'PS'},
     "Generator": {"side_effect_mem": True},
+    "InplaceAddExt": {"side_effect_mem": True},
+    "InplaceAddmm": {"side_effect_mem": True},
+    "InplaceAddsExt": {"side_effect_mem": True},
     "Log": {"cust_aicpu": 'Log', "base": -1.0, "scale": 1.0, "shift": 0.0},
     "PromptKVCache": {"side_effect_mem": True},
     "ReshapeAndCache": {"side_effect_mem": True},
     "ResizeD": {"mode": 'linear'},
+    "SilentCheckV2": {"side_effect_mem": True},
+    "KVCacheScatterUpdate": {"side_effect_mem": True},
 }
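These tables feed the generated primitive wrappers: op_args_default_value supplies per-op argument defaults, op_labels attaches attributes such as side-effect flags. A hedged sketch of inspecting them directly (module path taken from the file list above):

    # Hedged sketch -- reads the tables shown in this diff.
    from mindspore.ops.auto_generate.cpp_create_prim_instance_helper import (
        op_args_default_value, op_labels)

    print(op_args_default_value["Round"])   # {'decimals': 0}
    print(op_labels.get("SilentCheckV2"))   # {'side_effect_mem': True}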
mindspore/ops/auto_generate/gen_arg_dtype_cast.py CHANGED
@@ -238,6 +238,8 @@ def type_it(op_name, arg_name, data, src_type, dst_type):
     """
     cast operator argument data type.
     """
+    if isinstance(data, type(None)):
+        return data
     if not isinstance(src_type, tuple):
         src_type = int(src_type)
     else:
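The new guard lets optional (None) arguments skip dtype casting entirely. A self-contained sketch of the behavior, with a stand-in for the real casting logic:

    # Hedged sketch -- mirrors the added guard; int() stands in for the
    # real src_type/dst_type casting machinery in type_it.
    def type_it_sketch(data):
        if isinstance(data, type(None)):  # the guard added above
            return data
        return int(data)

    print(type_it_sketch(None))  # None passes through untouched
    print(type_it_sketch("3"))   # 3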