mindspore 2.2.0-cp39-cp39-win_amd64.whl → 2.2.11-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +3 -3
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/splitter.py +3 -2
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
- mindspore/_extends/parse/__init__.py +3 -2
- mindspore/_extends/parse/parser.py +6 -1
- mindspore/_extends/parse/standard_method.py +14 -11
- mindspore/_extends/remote/kernel_build_server.py +2 -1
- mindspore/common/_utils.py +16 -0
- mindspore/common/api.py +1 -1
- mindspore/common/auto_dynamic_shape.py +81 -85
- mindspore/common/dump.py +1 -1
- mindspore/common/tensor.py +3 -20
- mindspore/config/op_info.config +1 -1
- mindspore/context.py +11 -4
- mindspore/dataset/engine/cache_client.py +8 -5
- mindspore/dataset/engine/datasets_standard_format.py +5 -0
- mindspore/dataset/vision/transforms.py +21 -21
- mindspore/experimental/optim/adam.py +1 -1
- mindspore/gen_ops.py +1 -1
- mindspore/include/api/model.h +17 -0
- mindspore/include/api/status.h +8 -3
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/cell.py +0 -3
- mindspore/nn/layer/activation.py +4 -5
- mindspore/nn/layer/conv.py +39 -23
- mindspore/nn/layer/flash_attention.py +54 -129
- mindspore/nn/layer/math.py +3 -7
- mindspore/nn/layer/rnn_cells.py +5 -5
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +12 -3
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
- mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_utils/utils.py +2 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
- mindspore/ops/function/array_func.py +10 -7
- mindspore/ops/function/grad/grad_func.py +0 -1
- mindspore/ops/function/nn_func.py +98 -9
- mindspore/ops/function/random_func.py +2 -1
- mindspore/ops/op_info_register.py +24 -21
- mindspore/ops/operations/__init__.py +6 -2
- mindspore/ops/operations/_grad_ops.py +25 -6
- mindspore/ops/operations/_inner_ops.py +155 -23
- mindspore/ops/operations/array_ops.py +9 -7
- mindspore/ops/operations/comm_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +85 -68
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +7 -6
- mindspore/ops/operations/nn_ops.py +193 -49
- mindspore/parallel/_parallel_serialization.py +10 -3
- mindspore/parallel/_tensor.py +4 -1
- mindspore/parallel/checkpoint_transform.py +13 -2
- mindspore/parallel/shard.py +17 -10
- mindspore/profiler/common/util.py +1 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
- mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
- mindspore/profiler/parser/ascend_op_generator.py +1 -1
- mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
- mindspore/profiler/parser/base_timeline_generator.py +1 -1
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
- mindspore/profiler/parser/framework_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +19 -0
- mindspore/profiler/profiling.py +46 -24
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/parsers/for_parser.py +7 -7
- mindspore/rewrite/parsers/module_parser.py +4 -4
- mindspore/rewrite/symbol_tree.py +1 -4
- mindspore/run_check/_check_version.py +5 -3
- mindspore/safeguard/rewrite_obfuscation.py +52 -28
- mindspore/train/callback/_summary_collector.py +1 -1
- mindspore/train/dataset_helper.py +1 -0
- mindspore/train/model.py +2 -2
- mindspore/train/serialization.py +97 -11
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +23 -7
- mindspore/version.py +1 -1
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +101 -112
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
- mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/ops/operations/custom_ops.py

@@ -42,6 +42,24 @@ from ._pyfunc_registry import add_pyfunc
 if platform.system() != "Windows":
     import fcntl
 
+KEY_ATTR = "attr"
+KEY_NAME = "name"
+INPUT_NAMES = "input_names"
+ATTR_NAMES = "attr_names"
+AUTO_DIFF = "autodiff"
+IMPLY_TYPE = "imply_type"
+FUSION_TYPE = "fusion_type"
+MS_KERNEL_FLAG = "ms_kernel_flag"
+AKG = "AKG"
+TBE = "TBE"
+CUDA = "CUDA"
+AICORE = "AiCore"
+CPU = "CPU"
+GPU = "GPU"
+ASCEND = "Ascend"
+HYBRID_TYPE = "hybrid"
+OP_NAME = "op_name"
+
 
 def _get_cache_path():
     """
@@ -150,7 +168,6 @@ class Custom(ops.PrimitiveWithInfer):
 
     .. warning::
         - This is an experimental API that is subject to change.
-        - Currently, the functionality of Custom does not support Ascend 910B.
 
     .. note::
         The supported platforms are determined by the input `func_type`. The supported platforms are as follows:
@@ -453,10 +470,10 @@ class Custom(ops.PrimitiveWithInfer):
     op_path_in_cache = []  # Save paths for op functions created in the cached.
     custom_aot_warning = True  # Flag to enable warnings about custom aot path white list
 
-    def __init__(self, func, out_shape=None, out_dtype=None, func_type="hybrid", bprop=None, reg_info=None):
-        ops.PrimitiveWithInfer.__init__(self, "Custom")
+    def __init__(self, func, out_shape=None, out_dtype=None, func_type=HYBRID_TYPE, bprop=None, reg_info=None):
+        super().__init__("Custom")
 
-        self.supported_targets = ["Ascend", "GPU", "CPU"]
+        self.supported_targets = [ASCEND, GPU, CPU]
         self.supported_func_type = ["hybrid", "akg", "tbe", "aicpu", "aot", "pyfunc", "julia"]
         self.log_prefix = "For '{}', 'func_type': {}, 'func': {}".format(self.name, func_type, func)
         self.func = func
@@ -473,7 +490,7 @@ class Custom(ops.PrimitiveWithInfer):
         self._update_func_info(reg_info)
         self.add_prim_attr("func_name", self.func_name)
         self.add_prim_attr("uniq_name", self.uniq_name)
-        if self.func_type == "hybrid":
+        if self.func_type == HYBRID_TYPE:
             self.add_prim_attr("func_compile_attrs", self._func_compile_attrs)
 
         self.add_prim_attr("imply_path", self.imply_path)
@@ -502,7 +519,7 @@ class Custom(ops.PrimitiveWithInfer):
         if func_type == "akg":
             self._set_akg_kernel_type()
 
-        if not self.bprop and self.func_type == "hybrid":
+        if not self.bprop and self.func_type == HYBRID_TYPE:
             self._hybrid_autodiff(func_type)
 
         self.add_prim_attr("func_type", self.func_type)
@@ -577,7 +594,7 @@ class Custom(ops.PrimitiveWithInfer):
         elif "compute" in self.func_source_str:
             self.func_type = "tvm_compute"
         else:
-            self.func_type = "hybrid"
+            self.func_type = HYBRID_TYPE
             self._hybrid_func_analyser()
 
     def _check_julia_func(self):
@@ -633,18 +650,18 @@ class Custom(ops.PrimitiveWithInfer):
 
         elif self.func_type == "julia":
             self._check_julia_func()
-        elif self.func_type == "hybrid":
-            if not hasattr(self.func, "ms_kernel_flag"):
+        elif self.func_type == HYBRID_TYPE:
+            if not hasattr(self.func, MS_KERNEL_FLAG):
                 raise TypeError("{}, 'func' must be a function decorated by kernel".format(self.log_prefix))
             self._is_ms_kernel = True
             self._func_compile_attrs = getattr(self.func, "compile_attrs", {})
         elif self.func_type == "akg":
-            if hasattr(self.func, "ms_kernel_flag"):
+            if hasattr(self.func, MS_KERNEL_FLAG):
                 logger.warning("{}. To have a better user experience, the mode hybrid is suggested "
                                "for the input function with decorator @kernel. "
                                "To enable this mode, set the 'func_type' to be \"hybrid\"".format(self.log_prefix))
         elif self.func_type == "pyfunc":
-            if hasattr(self.func, "ms_kernel_flag"):
+            if hasattr(self.func, MS_KERNEL_FLAG):
                 logger.warning("{}. Now you are using the function with decorator @kernel in the mode pyfunc. "
                                "The kernel will be executed as a native python function, which might lead to "
                                "low efficiency. To accelerate the kernel, set the 'func_type' to be \"hybrid\""
@@ -758,7 +775,7 @@ class Custom(ops.PrimitiveWithInfer):
                     continue
                 if isinstance(reg_info_item, str):
                     reg_info_item = json.loads(reg_info_item)
-                prefix = "_".join([prefix, reg_info_item.get("op_name", "")])
+                prefix = "_".join([prefix, reg_info_item.get(OP_NAME, "")])
             self.uniq_name = prefix + "_" + self.func_name
         else:
             raise TypeError("For '{}', 'func' must be of type function or str, but got {}"
|
|
|
768
785
|
"""Update op attrs in reg_info."""
|
|
769
786
|
output_name_list = []
|
|
770
787
|
for _, item in enumerate(reg_info.get("outputs", [])):
|
|
771
|
-
if isinstance(item, dict) and item.get(
|
|
772
|
-
output_name_list.append(item.get(
|
|
788
|
+
if isinstance(item, dict) and item.get(KEY_NAME):
|
|
789
|
+
output_name_list.append(item.get(KEY_NAME))
|
|
773
790
|
if output_name_list:
|
|
774
791
|
self.add_prim_attr("output_names", output_name_list)
|
|
775
792
|
|
|
776
|
-
if isinstance(reg_info.get(
|
|
777
|
-
self.add_prim_attr("reg_op_name", reg_info.get(
|
|
793
|
+
if isinstance(reg_info.get(OP_NAME), str):
|
|
794
|
+
self.add_prim_attr("reg_op_name", reg_info.get(OP_NAME))
|
|
778
795
|
|
|
779
796
|
if self.func_type == "aicpu":
|
|
780
|
-
self.uniq_name = reg_info[
|
|
797
|
+
self.uniq_name = reg_info[OP_NAME]
|
|
781
798
|
self.add_prim_attr("uniq_name", self.uniq_name)
|
|
782
799
|
|
|
783
800
|
if self.func_type in ["aot", "aicpu"]:
|
|
784
|
-
if reg_info.get(
|
|
785
|
-
for item in reg_info[
|
|
801
|
+
if reg_info.get(KEY_ATTR) is not None and isinstance(reg_info[KEY_ATTR], list):
|
|
802
|
+
for item in reg_info[KEY_ATTR]:
|
|
786
803
|
if isinstance(item, dict) and item.get("value") is not None:
|
|
787
|
-
self.add_prim_attr(item[
|
|
804
|
+
self.add_prim_attr(item[KEY_NAME], item["value"])
|
|
788
805
|
|
|
789
806
|
def _register_info(self, info):
|
|
790
807
|
"""Register reg_info."""
|
|
@@ -802,7 +819,7 @@ class Custom(ops.PrimitiveWithInfer):
         if isinstance(reg_info, str):
             reg_info = json.loads(reg_info)
         if self.fake_output:
-            reg_info["outputs"].append(dict({"index": 0, "name": "y", "param_type": "required"}))
+            reg_info["outputs"].append(dict({"index": 0, KEY_NAME: "y", "param_type": "required"}))
             new_dtype_format = []
             for i in reg_info["dtype_format"]:
                 new_dtype_format.append(i + (DataType.I32_Default,))
@@ -874,16 +891,16 @@ class Custom(ops.PrimitiveWithInfer):
                             "'CustomRegOp' to generate the registration information, then pass it to 'reg_info' or "
                             "use 'custom_info_register' to bind it to 'func' if 'func' is a function."
                             .format(self.log_prefix, reg_info, type(reg_info)))
-        reg_info["op_name"] = self.uniq_name
-        reg_info["imply_type"] = self._get_imply_type(reg_info, target)
-        if not isinstance(reg_info.get("fusion_type"), str) or not reg_info["fusion_type"].strip():
-            reg_info["fusion_type"] = "OPAQUE"
+        reg_info[OP_NAME] = self.uniq_name
+        reg_info[IMPLY_TYPE] = self._get_imply_type(reg_info, target)
+        if not isinstance(reg_info.get(FUSION_TYPE), str) or not reg_info[FUSION_TYPE].strip():
+            reg_info[FUSION_TYPE] = "OPAQUE"
         # Supplement necessary info for TBE if these information is missing in reg_info
-        if reg_info["imply_type"] == "TBE":
-            if reg_info.get("attr") is not None and isinstance(reg_info["attr"], list):
-                for i, item in enumerate(reg_info["attr"]):
+        if reg_info[IMPLY_TYPE] == TBE:
+            if reg_info.get(KEY_ATTR) is not None and isinstance(reg_info[KEY_ATTR], list):
+                for i, item in enumerate(reg_info[KEY_ATTR]):
                     if isinstance(item, dict) and item.get("value") is None:
-                        reg_info["attr"][i]["value"] = "all"
+                        reg_info[KEY_ATTR][i]["value"] = "all"
             reg_info["async_flag"] = reg_info.get("async_flag", False)
             reg_info["binfile"] = "%s.so" % self.func_name
             reg_info["compute_cost"] = reg_info.get("compute_cost", 10)
@@ -891,8 +908,8 @@ class Custom(ops.PrimitiveWithInfer):
             reg_info["partial_flag"] = reg_info.get("partial_flag", True)
             reg_info["needCheckSupport"] = reg_info.get("need_check_supported", False)
         # Supplement necessary info for AKG if these information is missing in reg_info
-        if reg_info["imply_type"] == "AKG":
-            target_to_processor = {"Ascend": "AiCore", "GPU": "CUDA", "CPU": "CPU"}
+        if reg_info[IMPLY_TYPE] == AKG:
+            target_to_processor = {ASCEND: AICORE, GPU: CUDA, CPU: CPU}
             reg_info["processor"] = reg_info.get("processor", target_to_processor.get(target))
         return reg_info
 
@@ -905,15 +922,15 @@ class Custom(ops.PrimitiveWithInfer):
         # Infer target from reg_info["processor"], reg_info generated from AkgGpuRegOp or AkgAscendRegOp
         # will have the processor information.
         if target not in self.supported_targets:
-            processor_to_target = {"AiCore": "Ascend", "CUDA": "GPU", "CPU": "CPU"}
+            processor_to_target = {AICORE: ASCEND, CUDA: GPU, CPU: CPU}
             target = processor_to_target.get(reg_info.get("processor"))
-        # Infer target from reg_info["imply_type"]
+        # Infer target from reg_info[IMPLY_TYPE]
         if target not in self.supported_targets:
-            imply_type_to_target = {"TBE": "Ascend", "GPU": "GPU", "CPU": "CPU"}
-            target = imply_type_to_target.get(reg_info.get("imply_type"))
+            imply_type_to_target = {TBE: ASCEND, GPU: GPU, CPU: CPU}
+            target = imply_type_to_target.get(reg_info.get(IMPLY_TYPE))
         # Infer target from func_type
         if target not in self.supported_targets:
-            func_type_to_target = {"tbe": "Ascend", "pyfunc": "CPU"}
+            func_type_to_target = {"tbe": ASCEND, "pyfunc": CPU}
             target = func_type_to_target.get(self.func_type)
         if target not in self.supported_targets:
             raise ValueError("{}, target set in registration information must be one of {}, but got {}"
@@ -922,14 +939,14 @@ class Custom(ops.PrimitiveWithInfer):
 
     def _get_imply_type(self, reg_info, target):
         """Get imply_typ information."""
-        # Get imply_type from reg_info["imply_type"]
-        if isinstance(reg_info, dict) and isinstance(reg_info.get("imply_type"), str) and \
-                reg_info["imply_type"].strip():
-            return reg_info["imply_type"]
+        # Get imply_type from reg_info[IMPLY_TYPE]
+        if isinstance(reg_info, dict) and isinstance(reg_info.get(IMPLY_TYPE), str) and \
+                reg_info[IMPLY_TYPE].strip():
+            return reg_info[IMPLY_TYPE]
         # Infer imply_type from func_type
-        func_type_to_imply_type = {"hybrid": "AKG", "akg": "AKG", "tbe": "TBE", "aicpu": "AiCPU", "pyfunc": target,
-                                   "julia": target, "aot": "BiSheng" if target == "Ascend" else target}
-        return func_type_to_imply_type.get(self.func_type, "AKG")
+        func_type_to_imply_type = {"hybrid": AKG, "akg": AKG, "tbe": TBE, "aicpu": "AiCPU", "pyfunc": target,
+                                   "julia": target, "aot": "BiSheng" if target == ASCEND else target}
+        return func_type_to_imply_type.get(self.func_type, AKG)
 
     def _save_attr(self, reg_info):
         """Save input_names and attr_names of current func."""
@@ -943,18 +960,18 @@ class Custom(ops.PrimitiveWithInfer):
             return value
 
         tensor_inputs = _get_value_list("inputs")
-        attr = _get_value_list("attr")
+        attr = _get_value_list(KEY_ATTR)
         input_names = []  # include tensor input names and attr input names
         attr_names = []
         pure_input_names = []
         for item in tensor_inputs:
-            if isinstance(item, dict) and item.get("name") is not None:
-                input_names.append(item["name"])
-                pure_input_names.append(item["name"])
+            if isinstance(item, dict) and item.get(KEY_NAME) is not None:
+                input_names.append(item[KEY_NAME])
+                pure_input_names.append(item[KEY_NAME])
         # attr is converted from inputs only when graph mode or when inputs name is also in reg info
         attr_to_input_safe = bool(input_names) or context.get_context("mode") == ms.GRAPH_MODE
         for item in attr:
-            if isinstance(item, dict) and item.get("name") is not None:
+            if isinstance(item, dict) and item.get(KEY_NAME) is not None:
                 # for custom op with function tbe, we always add attrs to inputs as we don't
                 # deal with attr value here and leave them to the backend process to fit the
                 # usual process of tbe op compiling in mindspore
@@ -963,9 +980,9 @@ class Custom(ops.PrimitiveWithInfer):
                 # add attr name to input name only when the value of attr is None in reg info
                 # as we need to get values of attrs from inputs
                 if attr_to_input_safe and (self.func_type == "tbe" or item.get("value", None) is None):
-                    input_names.append(item["name"])
-                attr_names.append(item["name"])
-        cur_attr = {"input_names": input_names, "attr_names": attr_names, "pure_input_names": pure_input_names}
+                    input_names.append(item[KEY_NAME])
+                attr_names.append(item[KEY_NAME])
+        cur_attr = {INPUT_NAMES: input_names, ATTR_NAMES: attr_names, "pure_input_names": pure_input_names}
         # If func does not have attr, save current attr.
         # Else, check if current attr is same as previous saved one.
         prev_attr_names = attr_names
|
|
|
974
991
|
if not isinstance(func_attr, dict):
|
|
975
992
|
setattr(self.func, "func_attr", cur_attr)
|
|
976
993
|
else:
|
|
977
|
-
prev_attr_names = func_attr.get(
|
|
994
|
+
prev_attr_names = func_attr.get(ATTR_NAMES)
|
|
978
995
|
elif isinstance(self.func, str):
|
|
979
996
|
func_attr = Custom.attr_dict.get(self.func)
|
|
980
997
|
if not isinstance(func_attr, dict):
|
|
981
998
|
Custom.attr_dict[self.func] = cur_attr
|
|
982
999
|
else:
|
|
983
|
-
prev_attr_names = func_attr.get(
|
|
1000
|
+
prev_attr_names = func_attr.get(ATTR_NAMES)
|
|
984
1001
|
if attr_names != prev_attr_names:
|
|
985
1002
|
raise ValueError("{}, attr names set in registration information must be the same as previous saved one, "
|
|
986
1003
|
"but got {} vs {}".format(self.log_prefix, attr_names, prev_attr_names))
|
|
@@ -989,23 +1006,23 @@ class Custom(ops.PrimitiveWithInfer):
         """Add primitive_target to primitive's attr."""
         registered_targets = self._get_registered_targets()
         if self.func_type == "pyfunc":
-            self.set_device("CPU")
-            if registered_targets and registered_targets != ["CPU"]:
+            self.set_device(CPU)
+            if registered_targets and registered_targets != [CPU]:
                 logger.warning("{}, only supports CPU platform, but got registered target {}. "
                                "We will run it on CPU".format(self.log_prefix, registered_targets))
         elif self.func_type == "aot":
             if len(registered_targets) != 1:
                 logger.info("{}, target will be set according to context.".format(self.log_prefix))
-            elif registered_targets == ["GPU"]:
-                self.set_device("GPU")
-            elif registered_targets == ["CPU"]:
-                self.set_device("CPU")
+            elif registered_targets == [GPU]:
+                self.set_device(GPU)
+            elif registered_targets == [CPU]:
+                self.set_device(CPU)
         elif self.func_type == "julia":
-            self.set_device("CPU")
+            self.set_device(CPU)
             device_target = context.get_context('device_target')
-            if device_target == "CPU":
+            if device_target == CPU:
                 pass
-            elif device_target == "GPU" and registered_targets and registered_targets == ["CPU"]:
+            elif device_target == GPU and registered_targets and registered_targets == [CPU]:
                 logger.warning("{}, only supports CPU platform, but got registered target {}. "
                                "We will run it on CPU".format(self.log_prefix, registered_targets))
             else:
@@ -1028,15 +1045,15 @@ class Custom(ops.PrimitiveWithInfer):
         elif isinstance(self.func, str):
             func_attr = Custom.attr_dict.get(self.func)
             if isinstance(func_attr, dict):
-                _add_prim_attr("input_names")
-                _add_prim_attr("attr_names")
+                _add_prim_attr(INPUT_NAMES)
+                _add_prim_attr(ATTR_NAMES)
                 _add_prim_attr("pure_input_names")
         self._add_prim_target()
         if callable(self.func) and callable(self.out_shape):
-            if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == "autodiff":
-                self.add_prim_attr("autodiff", True)
+            if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == AUTO_DIFF:
+                self.add_prim_attr(AUTO_DIFF, True)
             else:
-                self.add_prim_attr("autodiff", False)
+                self.add_prim_attr(AUTO_DIFF, False)
 
     def _hybrid_autodiff(self, input_func_type):
         """generate backward op for a custom hybrid op"""
@@ -1052,7 +1069,7 @@ class Custom(ops.PrimitiveWithInfer):
         def infer_func(*args):
             return args[:inputs_num]
 
-        setattr(infer_func, "type", "autodiff")
+        setattr(infer_func, "type", AUTO_DIFF)
         op = Custom(func=self.func, out_shape=infer_func, out_dtype=infer_func,
                     func_type=input_func_type, bprop=True)
         self.bprop = grad_func(op)
mindspore/ops/operations/inner_ops.py

@@ -238,13 +238,14 @@ class LambApplyOptimizerAssign(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """Initialize LambApplyOptimizerAssign"""
+        self.var_shape = "var_shape"
         self.add_prim_attr('side_effect_mem', True)
 
     def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape,
                     beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape):
-        validator.check("var_shape", var_shape, "m_shape", m_shape, validator.EQ, self.name)
-        validator.check("var_shape", var_shape, "v_shape", v_shape, validator.EQ, self.name)
-        validator.check("var_shape", var_shape, "grad_shape", grad_shape, validator.EQ, self.name)
+        validator.check(self.var_shape, var_shape, "m_shape", m_shape, validator.EQ, self.name)
+        validator.check(self.var_shape, var_shape, "v_shape", v_shape, validator.EQ, self.name)
+        validator.check(self.var_shape, var_shape, "grad_shape", grad_shape, validator.EQ, self.name)
         return m_shape, v_shape, m_shape
 
     def infer_dtype(self, grad_dtype, v_dtype, m_dtype, var_dtype, beta1_dtype, sub1_dtype,
@@ -658,3 +659,25 @@ class ScaleGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """Initialize ScaleGrad"""
+
+
+class KVCacheMgr(Primitive):
+    """
+    Update past with cur and index along sequence axis.
+
+    Inputs:
+        - **past** (Parameter) - 4-D tensor with shape: :math:`(batch_size, num_head, seq_len, hidden_size)`.
+        - **cur** (Tensor) - 4-D tensor with shape: :math:`(batch_size, num_head, 1, hidden_size)`.
+        - **index** (Tensor) - 1-D tensor with shape: :math:`(batch_size,)`.
+
+    Outputs:
+        Tensor, has the same data type and shape as original `past`.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['past', 'cur', 'index'], outputs=['past'])
+        self.add_prim_attr('side_effect_mem', True)
mindspore/ops/operations/math_ops.py

@@ -1536,9 +1536,8 @@ class LpNorm(Primitive):
     """
 
     @prim_attr_register
-    def __init__(self, axis, p=2, keep_dims=False, epsilon=1e-12):
+    def __init__(self, axis=(), p=2, keep_dims=False, epsilon=1e-12):
         """Initialize LpNorm"""
-        super().__init__("LpNorm")
         validator.check_value_type("p", p, [int], self.name)
         validator.check_value_type("axis", axis, [int, tuple, list], self.name)
         validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
@@ -2494,6 +2493,7 @@ class Reciprocal(PrimitiveWithCheck):
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
 
     def infer_value(self, x):
+        """Infer value for Reciprocal"""
         if x is not None:
             x = x.asnumpy()
             out = 1.0 / x
@@ -2551,6 +2551,7 @@ class Pow(Primitive):
         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
 
     def infer_value(self, x, power):
+        """infer value for _BinaryOp"""
         if x is not None and power is not None:
             x = x.asnumpy()
             power = power.asnumpy()
@@ -2931,7 +2932,7 @@ class Histogram(Primitive):
     """
 
     @prim_attr_register
-    def __init__(self, bins=100, min=0.0, max=0.0):
+    def __init__(self, bins=100, min=0.0, max=0.0):
         """Initialize Histogram."""
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
         validator.check_value_type("bins", bins, [int], self.name)
@@ -6568,9 +6569,9 @@ class LinSpace(Primitive):
 
     Inputs:
         - **start** (Tensor) - Start value of interval, 0-D Tensor with dtype float32 or float64.
-        - **stop** (Tensor) - Last value of interval, 0-D Tensor with dtype
-        - **num** (int) - Number of ticks in the interval, inclusive of `start` and `stop`.
-
+        - **stop** (Tensor) - Last value of interval, 0-D Tensor with dtype float32 or float64.
+        - **num** (Union[int, Tensor]) - Number of ticks in the interval, inclusive of `start` and `stop`.
+          Must be a positive integer. When the input is Tensor, it must be a 0-D Tensor with dtype int32 or int64.
 
     Outputs:
         Tensor, has the same shape and dtype as `start`.