mindspore-2.2.0-cp39-cp39-win_amd64.whl → mindspore-2.2.11-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (112)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  3. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  5. mindspore/_checkparam.py +3 -3
  6. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  7. mindspore/_extends/graph_kernel/splitter.py +3 -2
  8. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  9. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  10. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  11. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  12. mindspore/_extends/parse/__init__.py +3 -2
  13. mindspore/_extends/parse/parser.py +6 -1
  14. mindspore/_extends/parse/standard_method.py +14 -11
  15. mindspore/_extends/remote/kernel_build_server.py +2 -1
  16. mindspore/common/_utils.py +16 -0
  17. mindspore/common/api.py +1 -1
  18. mindspore/common/auto_dynamic_shape.py +81 -85
  19. mindspore/common/dump.py +1 -1
  20. mindspore/common/tensor.py +3 -20
  21. mindspore/config/op_info.config +1 -1
  22. mindspore/context.py +11 -4
  23. mindspore/dataset/engine/cache_client.py +8 -5
  24. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  25. mindspore/dataset/vision/transforms.py +21 -21
  26. mindspore/experimental/optim/adam.py +1 -1
  27. mindspore/gen_ops.py +1 -1
  28. mindspore/include/api/model.h +17 -0
  29. mindspore/include/api/status.h +8 -3
  30. mindspore/mindspore_backend.dll +0 -0
  31. mindspore/mindspore_common.dll +0 -0
  32. mindspore/mindspore_core.dll +0 -0
  33. mindspore/mindspore_shared_lib.dll +0 -0
  34. mindspore/nn/cell.py +0 -3
  35. mindspore/nn/layer/activation.py +4 -5
  36. mindspore/nn/layer/conv.py +39 -23
  37. mindspore/nn/layer/flash_attention.py +54 -129
  38. mindspore/nn/layer/math.py +3 -7
  39. mindspore/nn/layer/rnn_cells.py +5 -5
  40. mindspore/nn/wrap/__init__.py +4 -2
  41. mindspore/nn/wrap/cell_wrapper.py +12 -3
  42. mindspore/numpy/utils_const.py +5 -5
  43. mindspore/opencv_core452.dll +0 -0
  44. mindspore/opencv_imgcodecs452.dll +0 -0
  45. mindspore/opencv_imgproc452.dll +0 -0
  46. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  47. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  48. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  49. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  50. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  51. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  52. mindspore/ops/_utils/utils.py +2 -0
  53. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  54. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  55. mindspore/ops/function/array_func.py +10 -7
  56. mindspore/ops/function/grad/grad_func.py +0 -1
  57. mindspore/ops/function/nn_func.py +98 -9
  58. mindspore/ops/function/random_func.py +2 -1
  59. mindspore/ops/op_info_register.py +24 -21
  60. mindspore/ops/operations/__init__.py +6 -2
  61. mindspore/ops/operations/_grad_ops.py +25 -6
  62. mindspore/ops/operations/_inner_ops.py +155 -23
  63. mindspore/ops/operations/array_ops.py +9 -7
  64. mindspore/ops/operations/comm_ops.py +2 -2
  65. mindspore/ops/operations/custom_ops.py +85 -68
  66. mindspore/ops/operations/inner_ops.py +26 -3
  67. mindspore/ops/operations/math_ops.py +7 -6
  68. mindspore/ops/operations/nn_ops.py +193 -49
  69. mindspore/parallel/_parallel_serialization.py +10 -3
  70. mindspore/parallel/_tensor.py +4 -1
  71. mindspore/parallel/checkpoint_transform.py +13 -2
  72. mindspore/parallel/shard.py +17 -10
  73. mindspore/profiler/common/util.py +1 -0
  74. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  75. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  76. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  77. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  78. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  79. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  80. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  81. mindspore/profiler/parser/framework_parser.py +1 -1
  82. mindspore/profiler/parser/profiler_info.py +19 -0
  83. mindspore/profiler/profiling.py +46 -24
  84. mindspore/rewrite/api/pattern_engine.py +1 -1
  85. mindspore/rewrite/parsers/for_parser.py +7 -7
  86. mindspore/rewrite/parsers/module_parser.py +4 -4
  87. mindspore/rewrite/symbol_tree.py +1 -4
  88. mindspore/run_check/_check_version.py +5 -3
  89. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  90. mindspore/train/callback/_summary_collector.py +1 -1
  91. mindspore/train/dataset_helper.py +1 -0
  92. mindspore/train/model.py +2 -2
  93. mindspore/train/serialization.py +97 -11
  94. mindspore/train/summary/_summary_adapter.py +1 -1
  95. mindspore/train/summary/summary_record.py +23 -7
  96. mindspore/version.py +1 -1
  97. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
  98. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +101 -112
  99. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  100. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
  101. mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
  102. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
  103. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
  104. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
  105. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
  106. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  107. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  108. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  109. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  110. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  111. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  112. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -42,6 +42,24 @@ from ._pyfunc_registry import add_pyfunc
  if platform.system() != "Windows":
  import fcntl

+ KEY_ATTR = "attr"
+ KEY_NAME = "name"
+ INPUT_NAMES = "input_names"
+ ATTR_NAMES = "attr_names"
+ AUTO_DIFF = "autodiff"
+ IMPLY_TYPE = "imply_type"
+ FUSION_TYPE = "fusion_type"
+ MS_KERNEL_FLAG = "ms_kernel_flag"
+ AKG = "AKG"
+ TBE = "TBE"
+ CUDA = "CUDA"
+ AICORE = "AiCore"
+ CPU = "CPU"
+ GPU = "GPU"
+ ASCEND = "Ascend"
+ HYBRID_TYPE = "hybrid"
+ OP_NAME = "op_name"
+

  def _get_cache_path():
  """
@@ -150,7 +168,6 @@ class Custom(ops.PrimitiveWithInfer):

  .. warning::
  - This is an experimental API that is subject to change.
- - Currently, the functionality of Custom does not support Ascend 910B.

  .. note::
  The supported platforms are determined by the input `func_type`. The supported platforms are as follows:
@@ -453,10 +470,10 @@ class Custom(ops.PrimitiveWithInfer):
  op_path_in_cache = [] # Save paths for op functions created in the cached.
  custom_aot_warning = True # Flag to enable warnings about custom aot path white list

- def __init__(self, func, out_shape=None, out_dtype=None, func_type="hybrid", bprop=None, reg_info=None):
- ops.PrimitiveWithInfer.__init__(self, "Custom")
+ def __init__(self, func, out_shape=None, out_dtype=None, func_type=HYBRID_TYPE, bprop=None, reg_info=None):
+ super().__init__("Custom")

- self.supported_targets = ["Ascend", "GPU", "CPU"]
+ self.supported_targets = [ASCEND, GPU, CPU]
  self.supported_func_type = ["hybrid", "akg", "tbe", "aicpu", "aot", "pyfunc", "julia"]
  self.log_prefix = "For '{}', 'func_type': {}, 'func': {}".format(self.name, func_type, func)
  self.func = func
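For context on the constructor refactored above, here is a minimal sketch of how `ops.Custom` is typically instantiated via the `pyfunc` path, which runs the Python function natively on CPU. The function name `add_np` and the lambda-based `out_shape`/`out_dtype` arguments are illustrative, not part of this diff:

```python
import numpy as np
from mindspore import Tensor, ops

def add_np(a, b):
    # Plain NumPy implementation executed as a native Python function.
    return np.add(a, b)

# out_shape/out_dtype may be given as callables over the input shapes/dtypes.
test_add = ops.Custom(add_np, lambda a, _: a, lambda a, _: a, func_type="pyfunc")

x0 = Tensor(np.array([0.0, 0.0, 1.0], dtype=np.float32))
x1 = Tensor(np.array([1.0, 1.0, 1.0], dtype=np.float32))
print(test_add(x0, x1))  # [1. 1. 2.]
```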
@@ -473,7 +490,7 @@ class Custom(ops.PrimitiveWithInfer):
  self._update_func_info(reg_info)
  self.add_prim_attr("func_name", self.func_name)
  self.add_prim_attr("uniq_name", self.uniq_name)
- if self.func_type == "hybrid":
+ if self.func_type == HYBRID_TYPE:
  self.add_prim_attr("func_compile_attrs", self._func_compile_attrs)

  self.add_prim_attr("imply_path", self.imply_path)
@@ -502,7 +519,7 @@ class Custom(ops.PrimitiveWithInfer):
  if func_type == "akg":
  self._set_akg_kernel_type()

- if not self.bprop and self.func_type == "hybrid":
+ if not self.bprop and self.func_type == HYBRID_TYPE:
  self._hybrid_autodiff(func_type)

  self.add_prim_attr("func_type", self.func_type)
@@ -577,7 +594,7 @@ class Custom(ops.PrimitiveWithInfer):
  elif "compute" in self.func_source_str:
  self.func_type = "tvm_compute"
  else:
- self.func_type = "hybrid"
+ self.func_type = HYBRID_TYPE
  self._hybrid_func_analyser()

  def _check_julia_func(self):
@@ -633,18 +650,18 @@ class Custom(ops.PrimitiveWithInfer):

  elif self.func_type == "julia":
  self._check_julia_func()
- elif self.func_type == "hybrid":
- if not hasattr(self.func, "ms_kernel_flag"):
+ elif self.func_type == HYBRID_TYPE:
+ if not hasattr(self.func, MS_KERNEL_FLAG):
  raise TypeError("{}, 'func' must be a function decorated by kernel".format(self.log_prefix))
  self._is_ms_kernel = True
  self._func_compile_attrs = getattr(self.func, "compile_attrs", {})
  elif self.func_type == "akg":
- if hasattr(self.func, "ms_kernel_flag"):
+ if hasattr(self.func, MS_KERNEL_FLAG):
  logger.warning("{}. To have a better user experience, the mode hybrid is suggested "
  "for the input function with decorator @kernel. "
  "To enable this mode, set the 'func_type' to be \"hybrid\"".format(self.log_prefix))
  elif self.func_type == "pyfunc":
- if hasattr(self.func, "ms_kernel_flag"):
+ if hasattr(self.func, MS_KERNEL_FLAG):
  logger.warning("{}. Now you are using the function with decorator @kernel in the mode pyfunc. "
  "The kernel will be executed as a native python function, which might lead to "
  "low efficiency. To accelerate the kernel, set the 'func_type' to be \"hybrid\""
@@ -758,7 +775,7 @@ class Custom(ops.PrimitiveWithInfer):
  continue
  if isinstance(reg_info_item, str):
  reg_info_item = json.loads(reg_info_item)
- prefix = "_".join([prefix, reg_info_item.get("op_name", "")])
+ prefix = "_".join([prefix, reg_info_item.get(OP_NAME, "")])
  self.uniq_name = prefix + "_" + self.func_name
  else:
  raise TypeError("For '{}', 'func' must be of type function or str, but got {}"
@@ -768,23 +785,23 @@ class Custom(ops.PrimitiveWithInfer):
  """Update op attrs in reg_info."""
  output_name_list = []
  for _, item in enumerate(reg_info.get("outputs", [])):
- if isinstance(item, dict) and item.get("name"):
- output_name_list.append(item.get("name"))
+ if isinstance(item, dict) and item.get(KEY_NAME):
+ output_name_list.append(item.get(KEY_NAME))
  if output_name_list:
  self.add_prim_attr("output_names", output_name_list)

- if isinstance(reg_info.get("op_name"), str):
- self.add_prim_attr("reg_op_name", reg_info.get("op_name"))
+ if isinstance(reg_info.get(OP_NAME), str):
+ self.add_prim_attr("reg_op_name", reg_info.get(OP_NAME))

  if self.func_type == "aicpu":
- self.uniq_name = reg_info["op_name"]
+ self.uniq_name = reg_info[OP_NAME]
  self.add_prim_attr("uniq_name", self.uniq_name)

  if self.func_type in ["aot", "aicpu"]:
- if reg_info.get("attr") is not None and isinstance(reg_info["attr"], list):
- for item in reg_info["attr"]:
+ if reg_info.get(KEY_ATTR) is not None and isinstance(reg_info[KEY_ATTR], list):
+ for item in reg_info[KEY_ATTR]:
  if isinstance(item, dict) and item.get("value") is not None:
- self.add_prim_attr(item["name"], item["value"])
+ self.add_prim_attr(item[KEY_NAME], item["value"])

  def _register_info(self, info):
  """Register reg_info."""
@@ -802,7 +819,7 @@ class Custom(ops.PrimitiveWithInfer):
  if isinstance(reg_info, str):
  reg_info = json.loads(reg_info)
  if self.fake_output:
- reg_info["outputs"].append(dict({"index": 0, "name": "y", "param_type": "required"}))
+ reg_info["outputs"].append(dict({"index": 0, KEY_NAME: "y", "param_type": "required"}))
  new_dtype_format = []
  for i in reg_info["dtype_format"]:
  new_dtype_format.append(i + (DataType.I32_Default,))
@@ -874,16 +891,16 @@ class Custom(ops.PrimitiveWithInfer):
  "'CustomRegOp' to generate the registration information, then pass it to 'reg_info' or "
  "use 'custom_info_register' to bind it to 'func' if 'func' is a function."
  .format(self.log_prefix, reg_info, type(reg_info)))
- reg_info["op_name"] = self.uniq_name
- reg_info["imply_type"] = self._get_imply_type(reg_info, target)
- if not isinstance(reg_info.get("fusion_type"), str) or not reg_info["fusion_type"].strip():
- reg_info["fusion_type"] = "OPAQUE"
+ reg_info[OP_NAME] = self.uniq_name
+ reg_info[IMPLY_TYPE] = self._get_imply_type(reg_info, target)
+ if not isinstance(reg_info.get(FUSION_TYPE), str) or not reg_info[FUSION_TYPE].strip():
+ reg_info[FUSION_TYPE] = "OPAQUE"
  # Supplement necessary info for TBE if these information is missing in reg_info
- if reg_info["imply_type"] == "TBE":
- if reg_info.get("attr") is not None and isinstance(reg_info["attr"], list):
- for i, item in enumerate(reg_info["attr"]):
+ if reg_info[IMPLY_TYPE] == TBE:
+ if reg_info.get(KEY_ATTR) is not None and isinstance(reg_info[KEY_ATTR], list):
+ for i, item in enumerate(reg_info[KEY_ATTR]):
  if isinstance(item, dict) and item.get("value") is None:
- reg_info["attr"][i]["value"] = "all"
+ reg_info[KEY_ATTR][i]["value"] = "all"
  reg_info["async_flag"] = reg_info.get("async_flag", False)
  reg_info["binfile"] = "%s.so" % self.func_name
  reg_info["compute_cost"] = reg_info.get("compute_cost", 10)
@@ -891,8 +908,8 @@ class Custom(ops.PrimitiveWithInfer):
  reg_info["partial_flag"] = reg_info.get("partial_flag", True)
  reg_info["needCheckSupport"] = reg_info.get("need_check_supported", False)
  # Supplement necessary info for AKG if these information is missing in reg_info
- if reg_info["imply_type"] == "AKG":
- target_to_processor = {"Ascend": "AiCore", "GPU": "CUDA", "CPU": "CPU"}
+ if reg_info[IMPLY_TYPE] == AKG:
+ target_to_processor = {ASCEND: AICORE, GPU: CUDA, CPU: CPU}
  reg_info["processor"] = reg_info.get("processor", target_to_processor.get(target))
  return reg_info

@@ -905,15 +922,15 @@ class Custom(ops.PrimitiveWithInfer):
  # Infer target from reg_info["processor"], reg_info generated from AkgGpuRegOp or AkgAscendRegOp
  # will have the processor information.
  if target not in self.supported_targets:
- processor_to_target = {"AiCore": "Ascend", "CUDA": "GPU", "CPU": "CPU"}
+ processor_to_target = {AICORE: ASCEND, CUDA: GPU, CPU: CPU}
  target = processor_to_target.get(reg_info.get("processor"))
- # Infer target from reg_info["imply_type"]
+ # Infer target from reg_info[IMPLY_TYPE]
  if target not in self.supported_targets:
- imply_type_to_target = {"TBE": "Ascend", "GPU": "GPU", "CPU": "CPU"}
- target = imply_type_to_target.get(reg_info.get("imply_type"))
+ imply_type_to_target = {TBE: ASCEND, GPU: GPU, CPU: CPU}
+ target = imply_type_to_target.get(reg_info.get(IMPLY_TYPE))
  # Infer target from func_type
  if target not in self.supported_targets:
- func_type_to_target = {"tbe": "Ascend", "pyfunc": "CPU"}
+ func_type_to_target = {"tbe": ASCEND, "pyfunc": CPU}
  target = func_type_to_target.get(self.func_type)
  if target not in self.supported_targets:
  raise ValueError("{}, target set in registration information must be one of {}, but got {}"
@@ -922,14 +939,14 @@ class Custom(ops.PrimitiveWithInfer):

  def _get_imply_type(self, reg_info, target):
  """Get imply_typ information."""
- # Get imply_type from reg_info["imply_type"]
- if isinstance(reg_info, dict) and isinstance(reg_info.get("imply_type"), str) and \
- reg_info["imply_type"].strip():
- return reg_info["imply_type"]
+ # Get imply_type from reg_info[IMPLY_TYPE]
+ if isinstance(reg_info, dict) and isinstance(reg_info.get(IMPLY_TYPE), str) and \
+ reg_info[IMPLY_TYPE].strip():
+ return reg_info[IMPLY_TYPE]
  # Infer imply_type from func_type
- func_type_to_imply_type = {"hybrid": "AKG", "akg": "AKG", "tbe": "TBE", "aicpu": "AiCPU", "pyfunc": target,
- "julia": target, "aot": "BiSheng" if target == "Ascend" else target}
- return func_type_to_imply_type.get(self.func_type, "AKG")
+ func_type_to_imply_type = {"hybrid": AKG, "akg": AKG, "tbe": TBE, "aicpu": "AiCPU", "pyfunc": target,
+ "julia": target, "aot": "BiSheng" if target == ASCEND else target}
+ return func_type_to_imply_type.get(self.func_type, AKG)

  def _save_attr(self, reg_info):
  """Save input_names and attr_names of current func."""
@@ -943,18 +960,18 @@ class Custom(ops.PrimitiveWithInfer):
  return value

  tensor_inputs = _get_value_list("inputs")
- attr = _get_value_list("attr")
+ attr = _get_value_list(KEY_ATTR)
  input_names = [] # include tensor input names and attr input names
  attr_names = []
  pure_input_names = []
  for item in tensor_inputs:
- if isinstance(item, dict) and item.get("name") is not None:
- input_names.append(item["name"])
- pure_input_names.append(item["name"])
+ if isinstance(item, dict) and item.get(KEY_NAME) is not None:
+ input_names.append(item[KEY_NAME])
+ pure_input_names.append(item[KEY_NAME])
  # attr is converted from inputs only when graph mode or when inputs name is also in reg info
  attr_to_input_safe = bool(input_names) or context.get_context("mode") == ms.GRAPH_MODE
  for item in attr:
- if isinstance(item, dict) and item.get("name") is not None:
+ if isinstance(item, dict) and item.get(KEY_NAME) is not None:
  # for custom op with function tbe, we always add attrs to inputs as we don't
  # deal with attr value here and leave them to the backend process to fit the
  # usual process of tbe op compiling in mindspore
@@ -963,9 +980,9 @@ class Custom(ops.PrimitiveWithInfer):
  # add attr name to input name only when the value of attr is None in reg info
  # as we need to get values of attrs from inputs
  if attr_to_input_safe and (self.func_type == "tbe" or item.get("value", None) is None):
- input_names.append(item["name"])
- attr_names.append(item["name"])
- cur_attr = {"input_names": input_names, "attr_names": attr_names, "pure_input_names": pure_input_names}
+ input_names.append(item[KEY_NAME])
+ attr_names.append(item[KEY_NAME])
+ cur_attr = {INPUT_NAMES: input_names, ATTR_NAMES: attr_names, "pure_input_names": pure_input_names}
  # If func does not have attr, save current attr.
  # Else, check if current attr is same as previous saved one.
  prev_attr_names = attr_names
@@ -974,13 +991,13 @@ class Custom(ops.PrimitiveWithInfer):
  if not isinstance(func_attr, dict):
  setattr(self.func, "func_attr", cur_attr)
  else:
- prev_attr_names = func_attr.get("attr_names")
+ prev_attr_names = func_attr.get(ATTR_NAMES)
  elif isinstance(self.func, str):
  func_attr = Custom.attr_dict.get(self.func)
  if not isinstance(func_attr, dict):
  Custom.attr_dict[self.func] = cur_attr
  else:
- prev_attr_names = func_attr.get("attr_names")
+ prev_attr_names = func_attr.get(ATTR_NAMES)
  if attr_names != prev_attr_names:
  raise ValueError("{}, attr names set in registration information must be the same as previous saved one, "
  "but got {} vs {}".format(self.log_prefix, attr_names, prev_attr_names))
@@ -989,23 +1006,23 @@ class Custom(ops.PrimitiveWithInfer):
  """Add primitive_target to primitive's attr."""
  registered_targets = self._get_registered_targets()
  if self.func_type == "pyfunc":
- self.set_device("CPU")
- if registered_targets and registered_targets != ["CPU"]:
+ self.set_device(CPU)
+ if registered_targets and registered_targets != [CPU]:
  logger.warning("{}, only supports CPU platform, but got registered target {}. "
  "We will run it on CPU".format(self.log_prefix, registered_targets))
  elif self.func_type == "aot":
  if len(registered_targets) != 1:
  logger.info("{}, target will be set according to context.".format(self.log_prefix))
- elif registered_targets == ["GPU"]:
- self.set_device("GPU")
- elif registered_targets == ["CPU"]:
- self.set_device("CPU")
+ elif registered_targets == [GPU]:
+ self.set_device(GPU)
+ elif registered_targets == [CPU]:
+ self.set_device(CPU)
  elif self.func_type == "julia":
- self.set_device("CPU")
+ self.set_device(CPU)
  device_target = context.get_context('device_target')
- if device_target == "CPU":
+ if device_target == CPU:
  pass
- elif device_target == "GPU" and registered_targets and registered_targets == ["CPU"]:
+ elif device_target == GPU and registered_targets and registered_targets == [CPU]:
  logger.warning("{}, only supports CPU platform, but got registered target {}. "
  "We will run it on CPU".format(self.log_prefix, registered_targets))
  else:
@@ -1028,15 +1045,15 @@ class Custom(ops.PrimitiveWithInfer):
  elif isinstance(self.func, str):
  func_attr = Custom.attr_dict.get(self.func)
  if isinstance(func_attr, dict):
- _add_prim_attr("input_names")
- _add_prim_attr("attr_names")
+ _add_prim_attr(INPUT_NAMES)
+ _add_prim_attr(ATTR_NAMES)
  _add_prim_attr("pure_input_names")
  self._add_prim_target()
  if callable(self.func) and callable(self.out_shape):
- if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == "autodiff":
- self.add_prim_attr("autodiff", True)
+ if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == AUTO_DIFF:
+ self.add_prim_attr(AUTO_DIFF, True)
  else:
- self.add_prim_attr("autodiff", False)
+ self.add_prim_attr(AUTO_DIFF, False)

  def _hybrid_autodiff(self, input_func_type):
  """generate backward op for a custom hybrid op"""
@@ -1052,7 +1069,7 @@ class Custom(ops.PrimitiveWithInfer):
  def infer_func(*args):
  return args[:inputs_num]

- setattr(infer_func, "type", "autodiff")
+ setattr(infer_func, "type", AUTO_DIFF)
  op = Custom(func=self.func, out_shape=infer_func, out_dtype=infer_func,
  func_type=input_func_type, bprop=True)
  self.bprop = grad_func(op)
@@ -238,13 +238,14 @@ class LambApplyOptimizerAssign(PrimitiveWithInfer):
  @prim_attr_register
  def __init__(self):
  """Initialize LambApplyOptimizerAssign"""
+ self.var_shape = "var_shape"
  self.add_prim_attr('side_effect_mem', True)

  def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape,
  beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape):
- validator.check("var_shape", var_shape, "m_shape", m_shape, validator.EQ, self.name)
- validator.check("var_shape", var_shape, "v_shape", v_shape, validator.EQ, self.name)
- validator.check("var_shape", var_shape, "grad_shape", grad_shape, validator.EQ, self.name)
+ validator.check(self.var_shape, var_shape, "m_shape", m_shape, validator.EQ, self.name)
+ validator.check(self.var_shape, var_shape, "v_shape", v_shape, validator.EQ, self.name)
+ validator.check(self.var_shape, var_shape, "grad_shape", grad_shape, validator.EQ, self.name)
  return m_shape, v_shape, m_shape

  def infer_dtype(self, grad_dtype, v_dtype, m_dtype, var_dtype, beta1_dtype, sub1_dtype,
@@ -658,3 +659,25 @@ class ScaleGrad(PrimitiveWithInfer):
  @prim_attr_register
  def __init__(self):
  """Initialize ScaleGrad"""
+
+
+ class KVCacheMgr(Primitive):
+ """
+ Update past with cur and index along sequence axis.
+
+ Inputs:
+ - **past** (Parameter) - 4-D tensor with shape: :math:`(batch_size, num_head, seq_len, hidden_size)`.
+ - **cur** (Tensor) - 4-D tensor with shape: :math:`(batch_size, num_head, 1, hidden_size)`.
+ - **index** (Tensor) - 1-D tensor with shape: :math:`(batch_size,)`.
+
+ Outputs:
+ Tensor, has the same data type and shape as original `past`.
+
+ Supported Platforms:
+ ``Ascend``
+ """
+
+ @prim_attr_register
+ def __init__(self):
+ self.init_prim_io_names(inputs=['past', 'cur', 'index'], outputs=['past'])
+ self.add_prim_attr('side_effect_mem', True)
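The new `KVCacheMgr` primitive added above writes the current step's key/value slice into a pre-allocated cache at a per-batch position along the sequence axis. A minimal NumPy sketch of the update the docstring describes follows; the actual Ascend kernel operates in place on the `past` Parameter and its exact semantics may differ:

```python
import numpy as np

def kv_cache_update(past, cur, index):
    """Write cur into past at per-batch positions along the sequence axis (axis 2)."""
    out = past.copy()
    for b in range(past.shape[0]):
        out[b, :, index[b], :] = cur[b, :, 0, :]
    return out

past = np.zeros((2, 4, 8, 16), dtype=np.float16)   # (batch_size, num_head, seq_len, hidden_size)
cur = np.ones((2, 4, 1, 16), dtype=np.float16)     # one new step per batch element
index = np.array([3, 5], dtype=np.int32)

updated = kv_cache_update(past, cur, index)
print(updated[0, 0, 3, 0], updated[1, 0, 5, 0])    # 1.0 1.0
```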
@@ -1536,9 +1536,8 @@ class LpNorm(Primitive):
  """

  @prim_attr_register
- def __init__(self, axis, p=2, keep_dims=False, epsilon=1e-12):
+ def __init__(self, axis=(), p=2, keep_dims=False, epsilon=1e-12):
  """Initialize LpNorm"""
- super().__init__("LpNorm")
  validator.check_value_type("p", p, [int], self.name)
  validator.check_value_type("axis", axis, [int, tuple, list], self.name)
  validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
@@ -2494,6 +2493,7 @@ class Reciprocal(PrimitiveWithCheck):
  self.init_prim_io_names(inputs=['x'], outputs=['y'])

  def infer_value(self, x):
+ """Infer value for Reciprocal"""
  if x is not None:
  x = x.asnumpy()
  out = 1.0 / x
@@ -2551,6 +2551,7 @@ class Pow(Primitive):
  self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])

  def infer_value(self, x, power):
+ """infer value for _BinaryOp"""
  if x is not None and power is not None:
  x = x.asnumpy()
  power = power.asnumpy()
@@ -2931,7 +2932,7 @@ class Histogram(Primitive):
  """

  @prim_attr_register
- def __init__(self, bins=100, min=0.0, max=0.0): # pylint: disable=W0622
+ def __init__(self, bins=100, min=0.0, max=0.0):
  """Initialize Histogram."""
  self.init_prim_io_names(inputs=['x'], outputs=['y'])
  validator.check_value_type("bins", bins, [int], self.name)
@@ -6568,9 +6569,9 @@ class LinSpace(Primitive):

  Inputs:
  - **start** (Tensor) - Start value of interval, 0-D Tensor with dtype float32 or float64.
- - **stop** (Tensor) - Last value of interval, 0-D Tensor with dtype float32 or float64.
- - **num** (int) - Number of ticks in the interval, inclusive of `start` and `stop`.
- Supported dtypes: int32, int64.
+ - **stop** (Tensor) - Last value of interval, 0-D Tensor with dtype float32 or float64.
+ - **num** (Union[int, Tensor]) - Number of ticks in the interval, inclusive of `start` and `stop`.
+ Must be a positive integer. When the input is Tensor, it must be a 0-D Tensor with dtype int32 or int64.

  Outputs:
  Tensor, has the same shape and dtype as `start`.
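The docstring update above states that `num` may now also be passed as a 0-D int32/int64 Tensor rather than only a Python int. A short usage sketch; Tensor-valued `num` follows the updated docstring and may depend on the backend:

```python
import mindspore as ms
from mindspore import Tensor, ops

linspace = ops.LinSpace()
start = Tensor(1.0, ms.float32)
stop = Tensor(10.0, ms.float32)

print(linspace(start, stop, 5))                    # num as a Python int
print(linspace(start, stop, Tensor(5, ms.int32)))  # num as a 0-D int32 Tensor
```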