mindspore-2.2.0-cp39-cp39-macosx_11_0_arm64.whl → mindspore-2.2.11-cp39-cp39-macosx_11_0_arm64.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- mindspore/.commit_id +1 -1
- mindspore/_c_dataengine.cpython-39-darwin.so +0 -0
- mindspore/_c_expression.cpython-39-darwin.so +0 -0
- mindspore/_checkparam.py +3 -3
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/splitter.py +3 -2
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
- mindspore/_extends/parse/__init__.py +3 -2
- mindspore/_extends/parse/parser.py +6 -1
- mindspore/_extends/parse/standard_method.py +14 -11
- mindspore/_extends/remote/kernel_build_server.py +2 -1
- mindspore/_mindspore_offline_debug.cpython-39-darwin.so +0 -0
- mindspore/common/_utils.py +16 -0
- mindspore/common/api.py +1 -1
- mindspore/common/auto_dynamic_shape.py +81 -85
- mindspore/common/dump.py +1 -1
- mindspore/common/tensor.py +3 -20
- mindspore/config/op_info.config +1 -1
- mindspore/context.py +11 -4
- mindspore/dataset/engine/cache_client.py +8 -5
- mindspore/dataset/engine/datasets_standard_format.py +5 -0
- mindspore/dataset/vision/transforms.py +21 -21
- mindspore/experimental/optim/adam.py +1 -1
- mindspore/gen_ops.py +1 -1
- mindspore/include/api/model.h +17 -0
- mindspore/include/api/status.h +8 -3
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_shared_lib.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/nn/cell.py +0 -3
- mindspore/nn/layer/activation.py +4 -5
- mindspore/nn/layer/conv.py +39 -23
- mindspore/nn/layer/flash_attention.py +54 -129
- mindspore/nn/layer/math.py +3 -7
- mindspore/nn/layer/rnn_cells.py +5 -5
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +12 -3
- mindspore/numpy/utils_const.py +5 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
- mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_utils/utils.py +2 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
- mindspore/ops/function/array_func.py +10 -7
- mindspore/ops/function/grad/grad_func.py +0 -1
- mindspore/ops/function/nn_func.py +98 -9
- mindspore/ops/function/random_func.py +2 -1
- mindspore/ops/op_info_register.py +24 -21
- mindspore/ops/operations/__init__.py +6 -2
- mindspore/ops/operations/_grad_ops.py +25 -6
- mindspore/ops/operations/_inner_ops.py +155 -23
- mindspore/ops/operations/array_ops.py +9 -7
- mindspore/ops/operations/comm_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +85 -68
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +7 -6
- mindspore/ops/operations/nn_ops.py +193 -49
- mindspore/parallel/_parallel_serialization.py +10 -3
- mindspore/parallel/_tensor.py +4 -1
- mindspore/parallel/checkpoint_transform.py +13 -2
- mindspore/parallel/shard.py +17 -10
- mindspore/profiler/common/util.py +1 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
- mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
- mindspore/profiler/parser/ascend_op_generator.py +1 -1
- mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
- mindspore/profiler/parser/base_timeline_generator.py +1 -1
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
- mindspore/profiler/parser/framework_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +19 -0
- mindspore/profiler/profiling.py +46 -24
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/parsers/for_parser.py +7 -7
- mindspore/rewrite/parsers/module_parser.py +4 -4
- mindspore/rewrite/symbol_tree.py +1 -4
- mindspore/run_check/_check_version.py +5 -3
- mindspore/safeguard/rewrite_obfuscation.py +52 -28
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/train/callback/_summary_collector.py +1 -1
- mindspore/train/dataset_helper.py +1 -0
- mindspore/train/model.py +2 -2
- mindspore/train/serialization.py +97 -11
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +23 -7
- mindspore/version.py +1 -1
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +105 -116
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
- mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/.commit_id
CHANGED
@@ -1 +1 @@
-__commit_id__ = ''[sha1]:
+__commit_id__ = ''[sha1]:8c390933,[branch]:(HEAD,origin/r2.2,r2.2)''

mindspore/_c_dataengine.cpython-39-darwin.so
CHANGED
Binary file

mindspore/_c_expression.cpython-39-darwin.so
CHANGED
Binary file
mindspore/_checkparam.py
CHANGED
@@ -720,9 +720,9 @@ def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
         type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
         num_types = len(valid_types)
         msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
-
-
-
+        raise TypeError(f'{msg_prefix} type of \'{arg_name}\' should be {"one of " if num_types > 1 else ""}' \
+                        f'\'{type_names if num_types > 1 else type_names[0]}\', ' \
+                        f'but got type \'{type(arg_value).__name__}\'.')

     # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
     # `check_value_type('x', True, [bool, int])` will check pass
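For reference, a minimal usage sketch of the rebuilt error path; the primitive name 'MyOp' and the argument values below are invented for illustration:

    from mindspore._checkparam import check_value_type

    # A wrong argument type raises TypeError with the reworked message, roughly:
    #   For 'MyOp', the type of 'x' should be one of '['int', 'str']', but got type 'float'.
    try:
        check_value_type('x', 1.5, [int, str], prim_name='MyOp')
    except TypeError as err:
        print(err)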
mindspore/_extends/graph_kernel/model/graph_split.py
CHANGED
@@ -83,23 +83,23 @@ class CommonPattern:
     def reshape(dom):
         """fuse strategy for reshape dom"""
         if dom.pattern != PrimLib.RESHAPE:
-            return []
+            return [], False
         min_area, forward_fuse = None, False
         for a, _ in dom.out_relations.items():
-            if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a)
-
-
+            if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a):
+                if min_area is None or a.pattern < min_area.pattern:
+                    min_area = a
         for a, _ in dom.in_relations.items():
-            if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom)
-
-
-        return ([min_area], forward_fuse) if min_area else []
+            if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom):
+                if min_area is None or a.pattern < min_area.pattern:
+                    min_area, forward_fuse = a, True
+        return ([min_area], forward_fuse) if min_area else ([], False)

     @staticmethod
     def isolate_reshape(dom):
         """fuse strategy for isolate reshape dom"""
         if dom.pattern != PrimLib.RESHAPE or len(dom.ops) != 1:
-            return []
+            return [], False
         for a, _ in dom.out_relations.items():
             if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and dom.check_acyclic(a):
                 return [a], False
@@ -107,59 +107,61 @@ class CommonPattern:
         if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and a.pattern <= PrimLib.BROADCAST and \
                 a.check_acyclic(dom):
             return [a], True
-        return []
+        return [], False

     @staticmethod
     def elemwise_depth(dom):
         """fuse strategy in depth for elemwise dom"""
         if dom.pattern != PrimLib.ELEMWISE or len(dom.in_relations) != 1:
-            return []
+            return [], False
         a, r = list(dom.in_relations.items())[0]
-        if a.pattern > PrimLib.ELEMWISE or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE
-
-
+        if a.pattern > PrimLib.ELEMWISE or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE:
+            return [], False
+        if tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
+            return [], False
         return [a], True

     @staticmethod
     def elemwise_width(dom):
         """fuse strategy in width for elemwise dom"""
         if dom.pattern != PrimLib.ELEMWISE:
-            return []
+            return [], False
         fused = []
         for a, r in dom.in_relations.items():
-            if a.pattern <= PrimLib.ELEMWISE and r <= PrimLib.ELEMWISE and a.check_acyclic(dom)
-
-
+            if a.pattern <= PrimLib.ELEMWISE and r <= PrimLib.ELEMWISE and a.check_acyclic(dom):
+                if tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
+                    fused.append(a)
         return fused, True

     @staticmethod
     def broadcast_depth(dom):
         """fuse strategy in depth for broadcast dom"""
         if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
-            return []
+            return [], False
         a, r = list(dom.in_relations.items())[0]
-        if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE
-
-
+        if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE:
+            return [], False
+        if tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
+            return [], False
         return [a], True

     @staticmethod
     def broadcast_width(dom):
         """fuse strategy in width for broadcast dom"""
         if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
-            return []
+            return [], False
         fused = []
         for a, r in dom.in_relations.items():
-            if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.ELEMWISE and a.check_acyclic(dom)
-
-
+            if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.ELEMWISE and a.check_acyclic(dom):
+                if tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
+                    fused.append(a)
         return fused, True

     @staticmethod
     def assign(dom):
         """fuse strategy for assign dom"""
         if len(dom.ops) != 1 or dom.dom_op().prim != "Assign":
-            return []
+            return [], False
         fused = []
         for a, _ in dom.in_relations.items():
             fused.append(a)
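All of the "return []" → "return [], False" edits above, and the matching ones through the rest of this file, are a single signature change: every fuse strategy now returns a (fused_areas, forward_fuse) pair even when nothing can fuse, so callers can unpack the result unconditionally. A minimal sketch of the convention with invented names:

    # strategy_old/strategy_new are illustrative, not MindSpore functions.
    def strategy_old(candidates):
        if not candidates:
            return []              # callers had to special-case the bare list
        return [candidates[0]], True

    def strategy_new(candidates):
        if not candidates:
            return [], False       # same tuple shape as the success case
        return [candidates[0]], True

    areas, forward = strategy_new([])         # safe unpack: ([], False)
    areas, forward = strategy_new(["area0"])  # (['area0'], True)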
@@ -711,8 +713,9 @@ class GraphSplitByPattern:
         for i in range(len(areas) - 1):
             dom = areas[i]
             for a in areas[i + 1:]:
-
-
+                can_fuse = dom.check_acyclic(a) and a.check_acyclic(dom) and selector(dom, a) \
+                    and self.limit_area_size(dom, [a], 64) and dom.fuse_confirm(a)
+                if can_fuse:
                     dom.fuse(a)
                     self.set_area_map(a.ops, dom)
                     self.areas.remove(a)
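The loop above now folds every precondition (acyclicity in both directions, the selector, the 64-op area-size limit, fuse_confirm) into a single can_fuse boolean before any state is mutated. A toy, self-contained sketch of that shape; Area, merge and the selector below are stand-ins, not MindSpore APIs:

    class Area:
        def __init__(self, name, size):
            self.name, self.size = name, size

        def merge(self, other):   # stands in for dom.fuse(a)
            self.name += "+" + other.name
            self.size += other.size

    def fuse_pass(areas, selector, max_size=64):
        i = 0
        while i < len(areas) - 1:
            dom, fused = areas[i], False
            for a in areas[i + 1:]:
                # evaluate all guards first, mutate only when every one holds
                can_fuse = selector(dom, a) and dom.size + a.size <= max_size
                if can_fuse:
                    dom.merge(a)
                    areas.remove(a)
                    fused = True
                    break
            if not fused:
                i += 1

    areas = [Area("a", 10), Area("b", 20), Area("c", 50)]
    fuse_pass(areas, selector=lambda x, y: True)
    print([ar.name for ar in areas])  # ['a+b', 'c']: fusing c would exceed max_size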
@@ -844,7 +847,7 @@ class GraphSplitByPattern:
         while stack:
             op = stack.pop()
             if len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST or len(ops) > max_weight:
-                return []
+                return [], []
             ops.append(op)
             for t in op.inputs:
                 if t.op in area.ops:
@@ -878,8 +881,8 @@ class GraphSplitByPattern:
             return []
         result = []
         for op in borders:
-
-
+            prod_ops, inputs = prods[op]
+            if prod_ops:
                 if sum([t.get_size() for t in inputs]) <= op.output.get_size():
                     pred = self.area_map.get(inputs[0].op) if inputs and inputs[0].op else None
                     result.append([pred, prod_ops[::-1]])
@@ -938,23 +941,25 @@ class GraphSplitGpu(GraphSplitByPattern):
            return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST

        def _broadcast_bwd_depth(dom):
-            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1
-
-
+            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1:
+                return [], False
+            if dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
+                return [], False
            a, r = list(dom.out_relations.items())[0]
            if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
-                return []
+                return [], False
            return [a], False

        def _broadcast_bwd_width(dom):
            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
                    dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
-                return []
+                return [], False
            fused = []
            for a, r in dom.out_relations.items():
-                if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a)
-
-
+                if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a):
+                    return [], False
+                if fused and tensor_size(fused[0].dom_op().output) != tensor_size(a.dom_op().output):
+                    return [], False
                fused.append(a)
            return fused, False
@@ -965,25 +970,25 @@ class GraphSplitGpu(GraphSplitByPattern):

        def _reduce_depth(dom):
            if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
-                return []
+                return [], False
            a, r = list(dom.in_relations.items())[0]
-            if dom.ops[0].inputs[0].dtype == "float16" and a.is_output
-
-
-
+            if dom.ops[0].inputs[0].dtype == "float16" and a.is_output:
+                if len(a.ops) >= 10 and _is_atomic_add_available(dom):
+                    # to evade the precision problem.
+                    return [], False
            if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
                return []
            return [a], True

        def _reduce_width(dom):
            if dom.pattern != PrimLib.REDUCE:
-                return []
+                return [], False
            fused = []
            for a, r in dom.in_relations.items():
-                if dom.ops[0].inputs[0].dtype == "float16" and a.is_output
-
-
-
+                if dom.ops[0].inputs[0].dtype == "float16" and a.is_output:
+                    if len(a.ops) >= 10 and _is_atomic_add_available(dom):
+                        # to evade the precision problem.
+                        continue
                if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
                    fused.append(a)
            return fused, True
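The float16 guard added to _reduce_depth and _reduce_width skips fusing a reduce with a large producer area (10 or more ops) when the atomic-add path is available, "to evade the precision problem". A pure-NumPy illustration of the kind of error a low-precision accumulator causes (toy numbers, not MindSpore code):

    import numpy as np

    x = np.full(4096, 0.1, dtype=np.float16)

    acc = np.float16(0)
    for v in x:                        # float16 accumulator, as in a fused kernel
        acc = np.float16(acc + v)

    print(acc)                         # stalls near 256.0 once the ulp exceeds 0.1
    print(x.astype(np.float32).sum())  # ~409.5 with a float32 accumulator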
@@ -1016,15 +1021,15 @@ class GraphSplitGpu(GraphSplitByPattern):

        def _reduce_output(dom):
            if dom.pattern != PrimLib.REDUCE:
-                return []
+                return [], False
            if _may_multi_filter(dom.ops):
-                return []
+                return [], False
            if _is_atomic_add_available(dom):
-                return []
+                return [], False
            is_all_reduce = tensor_size(dom.ops[0].output) == 1
            # excluded large size all reduce
            if is_all_reduce and dom.ops[0].inputs and tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
-                return []
+                return [], False

            fused = []
            for a, r in dom.out_relations.items():
@@ -1034,11 +1039,11 @@ class GraphSplitGpu(GraphSplitByPattern):

        def _reduce_stitch(dom):
            if dom.pattern != PrimLib.REDUCE:
-                return []
+                return [], False
            if tensor_size(dom.ops[0].output) == 1:
-                return []
+                return [], False
            if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
-                return []
+                return [], False

            fused = []
            for a, r in dom.out_relations.items():
@@ -1055,7 +1060,7 @@ class GraphSplitGpu(GraphSplitByPattern):

        def _transpose(dom):
            if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose":
-                return []
+                return [], False
            fused = []
            for a, _ in dom.in_relations.items():
                if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and len(a.ops) <= self.TRANSPOSE_FUSE_DEPTH:
@@ -1064,7 +1069,7 @@ class GraphSplitGpu(GraphSplitByPattern):

        def _strided_slice(dom):
            if dom.dom_op().prim != "StridedSlice":
-                return []
+                return [], False
            fused = []
            for a, _ in dom.in_relations.items():
                if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
@@ -1075,7 +1080,7 @@ class GraphSplitGpu(GraphSplitByPattern):
        def _gather_output(dom, reduce_fusion=False):
            gather_prims = ("Gather", "GatherNd", "CSRGather")
            if not dom.dom_op().prim in gather_prims:
-                return []
+                return [], False

            def _reduce_exclude(op, axis_list):
                """ Whether this operator should be excluded.
@@ -1173,7 +1178,7 @@ class GraphSplitGpu(GraphSplitByPattern):
            for a, _ in dom.out_relations.items():
                if _shape_consistent(gather_prims, appected_areas, dom, a) and dom.check_acyclic(a):
                    return [a], False
-            return []
+            return [], False

        def _broadcast_tot(dom):
            """Fuse rule for TensorScatterAdd and UnsortedSegmentSum."""
@@ -1182,13 +1187,13 @@ class GraphSplitGpu(GraphSplitByPattern):
                return bool(set(op1.inputs) & set(op2.inputs))

            if len(dom.ops) != 1:
-                return []
+                return [], False

            # Only fuse the first input for `TensorScatterAdd`` and the first and second input for `UnsortedSegmentSum`.
            fuse_arg = {"TensorScatterAdd": slice(1, None), "UnsortedSegmentSum": slice(0, 2)}
            arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
            if arg_idx == -1:
-                return []
+                return [], False
            fuse_tensor = dom.dom_op().inputs[arg_idx]

            for a, _ in dom.in_relations.items():
@@ -1200,27 +1205,30 @@ class GraphSplitGpu(GraphSplitByPattern):
                # Rule 2: Fuse op(reshape/elementwise/broadcast) in specified position inputs.
                if a.pattern <= PrimLib.BROADCAST and any((op.output in fuse_tensor for op in a.ops)):
                    return [a], True
-            return []
+            return [], False

        def _broadcast_onehot(dom, fwd=True):
            """Fuse rule for OneHot."""
            if dom.dom_op().prim != "OneHot":
-                return []
+                return [], False

            fused = []
            neighbours = dom.in_relations.items() if fwd else dom.out_relations.items()
            for a, _ in neighbours:
                if a.pattern <= PrimLib.BROADCAST:
-                    if
-
-
+                    if fwd:
+                        if a.check_acyclic(dom) and len(a.out_relations) == 1 and not a.is_output:
+                            fused.append(a)
+                    else:
+                        if dom.check_acyclic(a):
+                            fused.append(a)

            return fused, fwd

        def _elemwise_elemany(dom):
            """Fuse rule for elemany."""
            if dom.dom_op().prim != "ElemAny":
-                return []
+                return [], False

            fused = []
            for a, r in dom.in_relations.items():
@@ -1233,21 +1241,21 @@ class GraphSplitGpu(GraphSplitByPattern):
            """Fuse rule for injective """
            injective_ops = {"Transpose", "StridedSlice"}
            if dom.dom_op().prim not in injective_ops:
-                return []
+                return [], False
            to_ops = dom.dom_op().output.to_ops
            if dom.is_output or len(to_ops) != 1 or len(dom.out_relations) != 1:
-                return []
+                return [], False
            to_area = list(dom.out_relations.keys())[0]
            if (to_area.pattern >= PrimLib.REDUCE and to_area.dom_op().prim not in injective_ops) or \
                    to_ops[0] not in to_area.ops:
-                return []
+                return [], False
            if len(to_area.ops) > self.TRANSPOSE_FUSE_DEPTH:
-                return []
+                return [], False
            return [to_area], False

        def _h_broadcast(dom, a):
            if dom.pattern > PrimLib.BROADCAST:
-                return []
+                return [], False
            return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape

        def _h_reduce(dom, a):
@@ -1274,7 +1282,7 @@ class GraphSplitGpu(GraphSplitByPattern):
            fuse_arg = {"CSRReduceSum": slice(1, 3), "CSRGather": slice(2, 3)}
            arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
            if arg_idx == -1:
-                return []
+                return [], False
            fuse_tensor = dom.dom_op().inputs[arg_idx]
            for a, _ in dom.in_relations.items():
                if (a.dom_op().prim == "CSRGather" and a.dom_op().prim == dom.dom_op().prim and
@@ -1283,7 +1291,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
                        any([op.output in fuse_tensor for op in a.ops]):
                    return [a], True
-            return []
+            return [], False

        def _fuse_loop():
            self.fuse(CommonPattern.reshape)
mindspore/_extends/graph_kernel/splitter.py
CHANGED
@@ -50,8 +50,9 @@ def split_with_json(json_str, flags_str):
    def _load_repository(graph, flags):
        """Load repository if exists"""
        def check_repo(op, best_split, op_desc):
-            if not isinstance(best_split, dict)
-
+            if not isinstance(best_split, dict):
+                return False
+            if "group_num" not in best_split or "graph_mode" not in best_split or "split_result" not in best_split:
                logger.warning("The graph split repository of {} should be a dict which contains 'group_num', 'graph_mode' "
                               "and 'split_result' field, but got {}".format(op, best_split))
                return False