mindspore 2.2.0-cp38-cp38-macosx_11_0_arm64.whl → 2.2.11-cp38-cp38-macosx_11_0_arm64.whl

This diff compares publicly available package versions as released to their respective public registries. It is provided for informational purposes only and reflects the changes between those versions.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (115)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_c_dataengine.cpython-38-darwin.so +0 -0
  3. mindspore/_c_expression.cpython-38-darwin.so +0 -0
  4. mindspore/_checkparam.py +3 -3
  5. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  6. mindspore/_extends/graph_kernel/splitter.py +3 -2
  7. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  8. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  9. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  10. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  11. mindspore/_extends/parse/__init__.py +3 -2
  12. mindspore/_extends/parse/parser.py +6 -1
  13. mindspore/_extends/parse/standard_method.py +14 -11
  14. mindspore/_extends/remote/kernel_build_server.py +2 -1
  15. mindspore/_mindspore_offline_debug.cpython-38-darwin.so +0 -0
  16. mindspore/common/_utils.py +16 -0
  17. mindspore/common/api.py +1 -1
  18. mindspore/common/auto_dynamic_shape.py +81 -85
  19. mindspore/common/dump.py +1 -1
  20. mindspore/common/tensor.py +3 -20
  21. mindspore/config/op_info.config +1 -1
  22. mindspore/context.py +11 -4
  23. mindspore/dataset/engine/cache_client.py +8 -5
  24. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  25. mindspore/dataset/vision/transforms.py +21 -21
  26. mindspore/experimental/optim/adam.py +1 -1
  27. mindspore/gen_ops.py +1 -1
  28. mindspore/include/api/model.h +17 -0
  29. mindspore/include/api/status.h +8 -3
  30. mindspore/lib/libmindspore_backend.dylib +0 -0
  31. mindspore/lib/libmindspore_common.dylib +0 -0
  32. mindspore/lib/libmindspore_core.dylib +0 -0
  33. mindspore/lib/libmindspore_shared_lib.dylib +0 -0
  34. mindspore/lib/libnnacl.dylib +0 -0
  35. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  36. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  37. mindspore/nn/cell.py +0 -3
  38. mindspore/nn/layer/activation.py +4 -5
  39. mindspore/nn/layer/conv.py +39 -23
  40. mindspore/nn/layer/flash_attention.py +54 -129
  41. mindspore/nn/layer/math.py +3 -7
  42. mindspore/nn/layer/rnn_cells.py +5 -5
  43. mindspore/nn/wrap/__init__.py +4 -2
  44. mindspore/nn/wrap/cell_wrapper.py +12 -3
  45. mindspore/numpy/utils_const.py +5 -5
  46. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  47. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  48. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  49. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  50. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  51. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  52. mindspore/ops/_utils/utils.py +2 -0
  53. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  54. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  55. mindspore/ops/function/array_func.py +10 -7
  56. mindspore/ops/function/grad/grad_func.py +0 -1
  57. mindspore/ops/function/nn_func.py +98 -9
  58. mindspore/ops/function/random_func.py +2 -1
  59. mindspore/ops/op_info_register.py +24 -21
  60. mindspore/ops/operations/__init__.py +6 -2
  61. mindspore/ops/operations/_grad_ops.py +25 -6
  62. mindspore/ops/operations/_inner_ops.py +155 -23
  63. mindspore/ops/operations/array_ops.py +9 -7
  64. mindspore/ops/operations/comm_ops.py +2 -2
  65. mindspore/ops/operations/custom_ops.py +85 -68
  66. mindspore/ops/operations/inner_ops.py +26 -3
  67. mindspore/ops/operations/math_ops.py +7 -6
  68. mindspore/ops/operations/nn_ops.py +193 -49
  69. mindspore/parallel/_parallel_serialization.py +10 -3
  70. mindspore/parallel/_tensor.py +4 -1
  71. mindspore/parallel/checkpoint_transform.py +13 -2
  72. mindspore/parallel/shard.py +17 -10
  73. mindspore/profiler/common/util.py +1 -0
  74. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  75. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  76. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  77. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  78. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  79. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  80. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  81. mindspore/profiler/parser/framework_parser.py +1 -1
  82. mindspore/profiler/parser/profiler_info.py +19 -0
  83. mindspore/profiler/profiling.py +46 -24
  84. mindspore/rewrite/api/pattern_engine.py +1 -1
  85. mindspore/rewrite/parsers/for_parser.py +7 -7
  86. mindspore/rewrite/parsers/module_parser.py +4 -4
  87. mindspore/rewrite/symbol_tree.py +1 -4
  88. mindspore/run_check/_check_version.py +5 -3
  89. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  90. mindspore/scipy/ops.py +55 -5
  91. mindspore/scipy/optimize/__init__.py +3 -2
  92. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  93. mindspore/train/callback/_summary_collector.py +1 -1
  94. mindspore/train/dataset_helper.py +1 -0
  95. mindspore/train/model.py +2 -2
  96. mindspore/train/serialization.py +97 -11
  97. mindspore/train/summary/_summary_adapter.py +1 -1
  98. mindspore/train/summary/summary_record.py +23 -7
  99. mindspore/version.py +1 -1
  100. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
  101. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +104 -115
  102. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  103. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
  104. mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
  105. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
  106. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
  107. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
  108. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
  109. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  110. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  111. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  112. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  113. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  114. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  115. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/.commit_id CHANGED
@@ -1 +1 @@
- __commit_id__ = ''[sha1]:9390851d,[branch]:(HEAD,origin/r2.2,r2.2)''
+ __commit_id__ = ''[sha1]:8c390933,[branch]:(HEAD,origin/r2.2,r2.2)''
mindspore/_checkparam.py CHANGED
@@ -720,9 +720,9 @@ def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
  type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
  num_types = len(valid_types)
  msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
- type_name_msg = f'{type_names if num_types > 1 else type_names[0]}'
- msg = f'type of \'{arg_name}\' should be{"one of " if num_types > 1 else ""} \'{type_name_msg}\''
- raise TypeError(f'{msg_prefix} {msg}, but got type \'{type(arg_value).__name__}\'.')
+ raise TypeError(f'{msg_prefix} type of \'{arg_name}\' should be {"one of " if num_types > 1 else ""}' \
+ f'\'{type_names if num_types > 1 else type_names[0]}\', ' \
+ f'but got type \'{type(arg_value).__name__}\'.')

  # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
  # `check_value_type('x', True, [bool, int])` will check pass
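Note on the _checkparam.py hunk above: it only restructures how check_value_type assembles its TypeError message; the validation logic is unchanged. Below is a minimal, simplified sketch of the new message format. It is an illustrative standalone helper, not the actual MindSpore function, and it omits the bool/int special-casing mentioned in the context lines.

    # Simplified sketch of the reworked error message in check_value_type.
    # Illustrative only; not the MindSpore implementation.
    def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
        if isinstance(arg_value, tuple(valid_types)):
            return arg_value
        type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
        num_types = len(valid_types)
        msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
        raise TypeError(f'{msg_prefix} type of \'{arg_name}\' should be {"one of " if num_types > 1 else ""}'
                        f'\'{type_names if num_types > 1 else type_names[0]}\', '
                        f'but got type \'{type(arg_value).__name__}\'.')

    # check_value_type('axis', 1.5, [int, tuple]) would raise:
    # TypeError: The type of 'axis' should be one of '['int', 'tuple']', but got type 'float'.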
mindspore/_extends/graph_kernel/model/graph_split.py CHANGED
@@ -83,23 +83,23 @@ class CommonPattern:
  def reshape(dom):
  """fuse strategy for reshape dom"""
  if dom.pattern != PrimLib.RESHAPE:
- return []
+ return [], False
  min_area, forward_fuse = None, False
  for a, _ in dom.out_relations.items():
- if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
- (min_area is None or a.pattern < min_area.pattern):
- min_area = a
+ if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a):
+ if min_area is None or a.pattern < min_area.pattern:
+ min_area = a
  for a, _ in dom.in_relations.items():
- if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
- (min_area is None or a.pattern < min_area.pattern):
- min_area, forward_fuse = a, True
- return ([min_area], forward_fuse) if min_area else []
+ if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom):
+ if min_area is None or a.pattern < min_area.pattern:
+ min_area, forward_fuse = a, True
+ return ([min_area], forward_fuse) if min_area else ([], False)

  @staticmethod
  def isolate_reshape(dom):
  """fuse strategy for isolate reshape dom"""
  if dom.pattern != PrimLib.RESHAPE or len(dom.ops) != 1:
- return []
+ return [], False
  for a, _ in dom.out_relations.items():
  if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and dom.check_acyclic(a):
  return [a], False
@@ -107,59 +107,61 @@ class CommonPattern:
  if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and a.pattern <= PrimLib.BROADCAST and \
  a.check_acyclic(dom):
  return [a], True
- return []
+ return [], False

  @staticmethod
  def elemwise_depth(dom):
  """fuse strategy in depth for elemwise dom"""
  if dom.pattern != PrimLib.ELEMWISE or len(dom.in_relations) != 1:
- return []
+ return [], False
  a, r = list(dom.in_relations.items())[0]
- if a.pattern > PrimLib.ELEMWISE or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE or \
- tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
- return []
+ if a.pattern > PrimLib.ELEMWISE or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE:
+ return [], False
+ if tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
+ return [], False
  return [a], True

  @staticmethod
  def elemwise_width(dom):
  """fuse strategy in width for elemwise dom"""
  if dom.pattern != PrimLib.ELEMWISE:
- return []
+ return [], False
  fused = []
  for a, r in dom.in_relations.items():
- if a.pattern <= PrimLib.ELEMWISE and r <= PrimLib.ELEMWISE and a.check_acyclic(dom) and \
- tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
- fused.append(a)
+ if a.pattern <= PrimLib.ELEMWISE and r <= PrimLib.ELEMWISE and a.check_acyclic(dom):
+ if tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
+ fused.append(a)
  return fused, True

  @staticmethod
  def broadcast_depth(dom):
  """fuse strategy in depth for broadcast dom"""
  if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
- return []
+ return [], False
  a, r = list(dom.in_relations.items())[0]
- if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE or \
- tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
- return []
+ if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE:
+ return [], False
+ if tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
+ return [], False
  return [a], True

  @staticmethod
  def broadcast_width(dom):
  """fuse strategy in width for broadcast dom"""
  if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
- return []
+ return [], False
  fused = []
  for a, r in dom.in_relations.items():
- if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.ELEMWISE and a.check_acyclic(dom) and \
- tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
- fused.append(a)
+ if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.ELEMWISE and a.check_acyclic(dom):
+ if tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
+ fused.append(a)
  return fused, True

  @staticmethod
  def assign(dom):
  """fuse strategy for assign dom"""
  if len(dom.ops) != 1 or dom.dom_op().prim != "Assign":
- return []
+ return [], False
  fused = []
  for a, _ in dom.in_relations.items():
  fused.append(a)
@@ -711,8 +713,9 @@ class GraphSplitByPattern:
  for i in range(len(areas) - 1):
  dom = areas[i]
  for a in areas[i + 1:]:
- if dom.check_acyclic(a) and a.check_acyclic(dom) and \
- selector(dom, a) and self.limit_area_size(dom, [a], 64) and dom.fuse_confirm(a):
+ can_fuse = dom.check_acyclic(a) and a.check_acyclic(dom) and selector(dom, a) \
+ and self.limit_area_size(dom, [a], 64) and dom.fuse_confirm(a)
+ if can_fuse:
  dom.fuse(a)
  self.set_area_map(a.ops, dom)
  self.areas.remove(a)
@@ -844,7 +847,7 @@ class GraphSplitByPattern:
  while stack:
  op = stack.pop()
  if len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST or len(ops) > max_weight:
- return []
+ return [], []
  ops.append(op)
  for t in op.inputs:
  if t.op in area.ops:
@@ -878,8 +881,8 @@ class GraphSplitByPattern:
  return []
  result = []
  for op in borders:
- if prods[op]:
- prod_ops, inputs = prods[op]
+ prod_ops, inputs = prods[op]
+ if prod_ops:
  if sum([t.get_size() for t in inputs]) <= op.output.get_size():
  pred = self.area_map.get(inputs[0].op) if inputs and inputs[0].op else None
  result.append([pred, prod_ops[::-1]])
@@ -938,23 +941,25 @@ class GraphSplitGpu(GraphSplitByPattern):
  return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST

  def _broadcast_bwd_depth(dom):
- if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \
- dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
- return []
+ if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1:
+ return [], False
+ if dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
+ return [], False
  a, r = list(dom.out_relations.items())[0]
  if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
- return []
+ return [], False
  return [a], False

  def _broadcast_bwd_width(dom):
  if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
  dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
- return []
+ return [], False
  fused = []
  for a, r in dom.out_relations.items():
- if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
- (fused and tensor_size(fused[0].dom_op().output) != tensor_size(a.dom_op().output)):
- return []
+ if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a):
+ return [], False
+ if fused and tensor_size(fused[0].dom_op().output) != tensor_size(a.dom_op().output):
+ return [], False
  fused.append(a)
  return fused, False

@@ -965,25 +970,25 @@ class GraphSplitGpu(GraphSplitByPattern):

  def _reduce_depth(dom):
  if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
- return []
+ return [], False
  a, r = list(dom.in_relations.items())[0]
- if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
- _is_atomic_add_available(dom):
- # to evade the precision problem.
- return []
+ if dom.ops[0].inputs[0].dtype == "float16" and a.is_output:
+ if len(a.ops) >= 10 and _is_atomic_add_available(dom):
+ # to evade the precision problem.
+ return [], False
  if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
  return []
  return [a], True

  def _reduce_width(dom):
  if dom.pattern != PrimLib.REDUCE:
- return []
+ return [], False
  fused = []
  for a, r in dom.in_relations.items():
- if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
- _is_atomic_add_available(dom):
- # to evade the precision problem.
- continue
+ if dom.ops[0].inputs[0].dtype == "float16" and a.is_output:
+ if len(a.ops) >= 10 and _is_atomic_add_available(dom):
+ # to evade the precision problem.
+ continue
  if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
  fused.append(a)
  return fused, True
@@ -1016,15 +1021,15 @@ class GraphSplitGpu(GraphSplitByPattern):

  def _reduce_output(dom):
  if dom.pattern != PrimLib.REDUCE:
- return []
+ return [], False
  if _may_multi_filter(dom.ops):
- return []
+ return [], False
  if _is_atomic_add_available(dom):
- return []
+ return [], False
  is_all_reduce = tensor_size(dom.ops[0].output) == 1
  # excluded large size all reduce
  if is_all_reduce and dom.ops[0].inputs and tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
- return []
+ return [], False

  fused = []
  for a, r in dom.out_relations.items():
@@ -1034,11 +1039,11 @@ class GraphSplitGpu(GraphSplitByPattern):

  def _reduce_stitch(dom):
  if dom.pattern != PrimLib.REDUCE:
- return []
+ return [], False
  if tensor_size(dom.ops[0].output) == 1:
- return []
+ return [], False
  if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
- return []
+ return [], False

  fused = []
  for a, r in dom.out_relations.items():
@@ -1055,7 +1060,7 @@ class GraphSplitGpu(GraphSplitByPattern):

  def _transpose(dom):
  if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose":
- return []
+ return [], False
  fused = []
  for a, _ in dom.in_relations.items():
  if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and len(a.ops) <= self.TRANSPOSE_FUSE_DEPTH:
@@ -1064,7 +1069,7 @@ class GraphSplitGpu(GraphSplitByPattern):

  def _strided_slice(dom):
  if dom.dom_op().prim != "StridedSlice":
- return []
+ return [], False
  fused = []
  for a, _ in dom.in_relations.items():
  if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
@@ -1075,7 +1080,7 @@ class GraphSplitGpu(GraphSplitByPattern):
  def _gather_output(dom, reduce_fusion=False):
  gather_prims = ("Gather", "GatherNd", "CSRGather")
  if not dom.dom_op().prim in gather_prims:
- return []
+ return [], False

  def _reduce_exclude(op, axis_list):
  """ Whether this operator should be excluded.
@@ -1173,7 +1178,7 @@ class GraphSplitGpu(GraphSplitByPattern):
  for a, _ in dom.out_relations.items():
  if _shape_consistent(gather_prims, appected_areas, dom, a) and dom.check_acyclic(a):
  return [a], False
- return []
+ return [], False

  def _broadcast_tot(dom):
  """Fuse rule for TensorScatterAdd and UnsortedSegmentSum."""
@@ -1182,13 +1187,13 @@ class GraphSplitGpu(GraphSplitByPattern):
  return bool(set(op1.inputs) & set(op2.inputs))

  if len(dom.ops) != 1:
- return []
+ return [], False

  # Only fuse the first input for `TensorScatterAdd`` and the first and second input for `UnsortedSegmentSum`.
  fuse_arg = {"TensorScatterAdd": slice(1, None), "UnsortedSegmentSum": slice(0, 2)}
  arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
  if arg_idx == -1:
- return []
+ return [], False
  fuse_tensor = dom.dom_op().inputs[arg_idx]

  for a, _ in dom.in_relations.items():
@@ -1200,27 +1205,30 @@ class GraphSplitGpu(GraphSplitByPattern):
  # Rule 2: Fuse op(reshape/elementwise/broadcast) in specified position inputs.
  if a.pattern <= PrimLib.BROADCAST and any((op.output in fuse_tensor for op in a.ops)):
  return [a], True
- return []
+ return [], False

  def _broadcast_onehot(dom, fwd=True):
  """Fuse rule for OneHot."""
  if dom.dom_op().prim != "OneHot":
- return []
+ return [], False

  fused = []
  neighbours = dom.in_relations.items() if fwd else dom.out_relations.items()
  for a, _ in neighbours:
  if a.pattern <= PrimLib.BROADCAST:
- if (fwd and a.check_acyclic(dom) and len(a.out_relations) == 1 and not a.is_output) or \
- (not fwd and dom.check_acyclic(a)):
- fused.append(a)
+ if fwd:
+ if a.check_acyclic(dom) and len(a.out_relations) == 1 and not a.is_output:
+ fused.append(a)
+ else:
+ if dom.check_acyclic(a):
+ fused.append(a)

  return fused, fwd

  def _elemwise_elemany(dom):
  """Fuse rule for elemany."""
  if dom.dom_op().prim != "ElemAny":
- return []
+ return [], False

  fused = []
  for a, r in dom.in_relations.items():
@@ -1233,21 +1241,21 @@ class GraphSplitGpu(GraphSplitByPattern):
  """Fuse rule for injective """
  injective_ops = {"Transpose", "StridedSlice"}
  if dom.dom_op().prim not in injective_ops:
- return []
+ return [], False
  to_ops = dom.dom_op().output.to_ops
  if dom.is_output or len(to_ops) != 1 or len(dom.out_relations) != 1:
- return []
+ return [], False
  to_area = list(dom.out_relations.keys())[0]
  if (to_area.pattern >= PrimLib.REDUCE and to_area.dom_op().prim not in injective_ops) or \
  to_ops[0] not in to_area.ops:
- return []
+ return [], False
  if len(to_area.ops) > self.TRANSPOSE_FUSE_DEPTH:
- return []
+ return [], False
  return [to_area], False

  def _h_broadcast(dom, a):
  if dom.pattern > PrimLib.BROADCAST:
- return []
+ return [], False
  return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape

  def _h_reduce(dom, a):
@@ -1274,7 +1282,7 @@ class GraphSplitGpu(GraphSplitByPattern):
  fuse_arg = {"CSRReduceSum": slice(1, 3), "CSRGather": slice(2, 3)}
  arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
  if arg_idx == -1:
- return []
+ return [], False
  fuse_tensor = dom.dom_op().inputs[arg_idx]
  for a, _ in dom.in_relations.items():
  if (a.dom_op().prim == "CSRGather" and a.dom_op().prim == dom.dom_op().prim and
@@ -1283,7 +1291,7 @@ class GraphSplitGpu(GraphSplitByPattern):
  if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
  any([op.output in fuse_tensor for op in a.ops]):
  return [a], True
- return []
+ return [], False

  def _fuse_loop():
  self.fuse(CommonPattern.reshape)
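The recurring fix across the graph_split.py hunks above is that every fuse-strategy function now returns a uniform (areas, forward_fuse) pair on its early-exit paths instead of a bare []. A minimal sketch of why the uniform contract matters to a caller follows; the driver shown is illustrative only, not the actual GraphSplitByPattern.fuse loop.

    # Illustrative sketch: uniform return contract for fuse strategies.
    def strategy_no_match(dom):
        # The early-exit path now has the same shape as the success path.
        return [], False

    def run_strategy(strategy, dom):
        # With a 2-tuple on every path the caller can unpack unconditionally;
        # a bare [] would raise "not enough values to unpack" here.
        areas, forward_fuse = strategy(dom)
        if not areas:
            return None
        return areas, forward_fuse

    print(run_strategy(strategy_no_match, dom=None))  # -> None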
mindspore/_extends/graph_kernel/splitter.py CHANGED
@@ -50,8 +50,9 @@ def split_with_json(json_str, flags_str):
  def _load_repository(graph, flags):
  """Load repository if exists"""
  def check_repo(op, best_split, op_desc):
- if not isinstance(best_split, dict) or "group_num" not in best_split or "graph_mode" not in best_split \
- or "split_result" not in best_split:
+ if not isinstance(best_split, dict):
+ return False
+ if "group_num" not in best_split or "graph_mode" not in best_split or "split_result" not in best_split:
  logger.warning("The graph split repository of {} should be a dict which contains 'group_num', 'graph_mode' "
  "and 'split_result' field, but got {}".format(op, best_split))
  return False
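The splitter.py change splits check_repo's compound condition so that a non-dict repository entry is rejected before any key lookups. A standalone sketch of the same validation shape follows; the required field names come from the diff above, while the example values are invented for illustration.

    # Illustrative sketch of the repository-entry validation in check_repo.
    REQUIRED_FIELDS = ("group_num", "graph_mode", "split_result")

    def check_repo_entry(best_split):
        # A non-dict entry is rejected before any key lookups.
        if not isinstance(best_split, dict):
            return False
        # All three fields must be present, mirroring the check in the diff.
        return all(field in best_split for field in REQUIRED_FIELDS)

    # Invented example entry, keys taken from the diff above:
    print(check_repo_entry({"group_num": 2, "graph_mode": ["basic"], "split_result": []}))  # True
    print(check_repo_entry("not-a-dict"))  # False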