mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of mindspore might be problematic.

Files changed (275)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/common/__init__.py +6 -4
  13. mindspore/common/_pijit_context.py +190 -0
  14. mindspore/common/_register_for_tensor.py +2 -1
  15. mindspore/common/_tensor_overload.py +139 -0
  16. mindspore/common/api.py +102 -87
  17. mindspore/common/dump.py +5 -6
  18. mindspore/common/generator.py +1 -7
  19. mindspore/common/hook_handle.py +14 -26
  20. mindspore/common/initializer.py +51 -15
  21. mindspore/common/mindir_util.py +2 -2
  22. mindspore/common/parameter.py +62 -15
  23. mindspore/common/recompute.py +39 -9
  24. mindspore/common/sparse_tensor.py +7 -3
  25. mindspore/common/tensor.py +183 -37
  26. mindspore/communication/__init__.py +1 -1
  27. mindspore/communication/_comm_helper.py +38 -3
  28. mindspore/communication/comm_func.py +315 -60
  29. mindspore/communication/management.py +14 -14
  30. mindspore/context.py +132 -22
  31. mindspore/dataset/__init__.py +1 -1
  32. mindspore/dataset/audio/__init__.py +1 -1
  33. mindspore/dataset/core/config.py +7 -0
  34. mindspore/dataset/core/validator_helpers.py +7 -0
  35. mindspore/dataset/engine/cache_client.py +1 -1
  36. mindspore/dataset/engine/datasets.py +72 -44
  37. mindspore/dataset/engine/datasets_audio.py +7 -7
  38. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  39. mindspore/dataset/engine/datasets_text.py +20 -20
  40. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  41. mindspore/dataset/engine/datasets_vision.py +33 -33
  42. mindspore/dataset/engine/iterators.py +29 -0
  43. mindspore/dataset/engine/obs/util.py +7 -0
  44. mindspore/dataset/engine/queue.py +114 -60
  45. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  46. mindspore/dataset/engine/validators.py +34 -14
  47. mindspore/dataset/text/__init__.py +1 -4
  48. mindspore/dataset/transforms/__init__.py +0 -3
  49. mindspore/dataset/utils/line_reader.py +2 -0
  50. mindspore/dataset/vision/__init__.py +1 -4
  51. mindspore/dataset/vision/utils.py +1 -1
  52. mindspore/dataset/vision/validators.py +2 -1
  53. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  54. mindspore/experimental/es/embedding_service.py +883 -0
  55. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  56. mindspore/experimental/llm_boost/__init__.py +21 -0
  57. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  58. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  59. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  60. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  61. mindspore/experimental/llm_boost/register.py +129 -0
  62. mindspore/experimental/llm_boost/utils.py +31 -0
  63. mindspore/experimental/optim/adamw.py +85 -0
  64. mindspore/experimental/optim/optimizer.py +3 -0
  65. mindspore/hal/__init__.py +3 -3
  66. mindspore/hal/contiguous_tensors_handle.py +175 -0
  67. mindspore/hal/stream.py +18 -0
  68. mindspore/include/api/model_group.h +13 -1
  69. mindspore/include/api/types.h +10 -10
  70. mindspore/include/dataset/config.h +2 -2
  71. mindspore/include/dataset/constants.h +2 -2
  72. mindspore/include/dataset/execute.h +2 -2
  73. mindspore/include/dataset/vision.h +4 -0
  74. mindspore/log.py +1 -1
  75. mindspore/mindrecord/filewriter.py +68 -51
  76. mindspore/mindspore_backend.dll +0 -0
  77. mindspore/mindspore_common.dll +0 -0
  78. mindspore/mindspore_core.dll +0 -0
  79. mindspore/mindspore_np_dtype.dll +0 -0
  80. mindspore/mindspore_ops.dll +0 -0
  81. mindspore/mint/__init__.py +983 -46
  82. mindspore/mint/distributed/__init__.py +31 -0
  83. mindspore/mint/distributed/distributed.py +254 -0
  84. mindspore/mint/nn/__init__.py +268 -23
  85. mindspore/mint/nn/functional.py +125 -19
  86. mindspore/mint/nn/layer/__init__.py +39 -0
  87. mindspore/mint/nn/layer/activation.py +133 -0
  88. mindspore/mint/nn/layer/normalization.py +477 -0
  89. mindspore/mint/nn/layer/pooling.py +110 -0
  90. mindspore/mint/optim/adamw.py +26 -13
  91. mindspore/mint/special/__init__.py +63 -0
  92. mindspore/multiprocessing/__init__.py +2 -1
  93. mindspore/nn/__init__.py +0 -1
  94. mindspore/nn/cell.py +276 -96
  95. mindspore/nn/layer/activation.py +211 -44
  96. mindspore/nn/layer/basic.py +137 -10
  97. mindspore/nn/layer/embedding.py +137 -2
  98. mindspore/nn/layer/normalization.py +101 -5
  99. mindspore/nn/layer/padding.py +34 -48
  100. mindspore/nn/layer/pooling.py +161 -7
  101. mindspore/nn/layer/transformer.py +3 -3
  102. mindspore/nn/loss/__init__.py +2 -2
  103. mindspore/nn/loss/loss.py +84 -6
  104. mindspore/nn/optim/__init__.py +2 -1
  105. mindspore/nn/optim/adadelta.py +1 -1
  106. mindspore/nn/optim/adam.py +1 -1
  107. mindspore/nn/optim/lamb.py +1 -1
  108. mindspore/nn/optim/tft_wrapper.py +124 -0
  109. mindspore/nn/wrap/cell_wrapper.py +12 -23
  110. mindspore/nn/wrap/grad_reducer.py +5 -5
  111. mindspore/nn/wrap/loss_scale.py +17 -3
  112. mindspore/numpy/__init__.py +1 -1
  113. mindspore/numpy/array_creations.py +65 -68
  114. mindspore/numpy/array_ops.py +64 -60
  115. mindspore/numpy/fft.py +610 -75
  116. mindspore/numpy/logic_ops.py +11 -10
  117. mindspore/numpy/math_ops.py +85 -84
  118. mindspore/numpy/utils_const.py +4 -4
  119. mindspore/opencv_core452.dll +0 -0
  120. mindspore/opencv_imgcodecs452.dll +0 -0
  121. mindspore/opencv_imgproc452.dll +0 -0
  122. mindspore/ops/__init__.py +6 -4
  123. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  124. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  125. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  126. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  127. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  128. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  129. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  130. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  131. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  132. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  133. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  134. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  135. mindspore/ops/composite/base.py +85 -48
  136. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  137. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  138. mindspore/ops/function/__init__.py +22 -0
  139. mindspore/ops/function/array_func.py +492 -153
  140. mindspore/ops/function/debug_func.py +113 -1
  141. mindspore/ops/function/fft_func.py +15 -2
  142. mindspore/ops/function/grad/grad_func.py +3 -2
  143. mindspore/ops/function/math_func.py +564 -207
  144. mindspore/ops/function/nn_func.py +817 -383
  145. mindspore/ops/function/other_func.py +3 -2
  146. mindspore/ops/function/random_func.py +402 -12
  147. mindspore/ops/function/reshard_func.py +13 -11
  148. mindspore/ops/function/sparse_unary_func.py +1 -1
  149. mindspore/ops/function/vmap_func.py +3 -2
  150. mindspore/ops/functional.py +24 -14
  151. mindspore/ops/op_info_register.py +3 -3
  152. mindspore/ops/operations/__init__.py +7 -2
  153. mindspore/ops/operations/_grad_ops.py +2 -76
  154. mindspore/ops/operations/_infer_ops.py +1 -1
  155. mindspore/ops/operations/_inner_ops.py +71 -94
  156. mindspore/ops/operations/array_ops.py +14 -146
  157. mindspore/ops/operations/comm_ops.py +63 -53
  158. mindspore/ops/operations/custom_ops.py +83 -19
  159. mindspore/ops/operations/debug_ops.py +42 -10
  160. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  161. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  162. mindspore/ops/operations/math_ops.py +12 -223
  163. mindspore/ops/operations/nn_ops.py +20 -114
  164. mindspore/ops/operations/other_ops.py +7 -4
  165. mindspore/ops/operations/random_ops.py +46 -1
  166. mindspore/ops/primitive.py +18 -6
  167. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  168. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  169. mindspore/ops_generate/gen_constants.py +36 -0
  170. mindspore/ops_generate/gen_ops.py +67 -52
  171. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  172. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  173. mindspore/ops_generate/op_proto.py +10 -3
  174. mindspore/ops_generate/pyboost_utils.py +14 -1
  175. mindspore/ops_generate/template.py +43 -21
  176. mindspore/parallel/__init__.py +3 -1
  177. mindspore/parallel/_auto_parallel_context.py +31 -9
  178. mindspore/parallel/_cell_wrapper.py +85 -0
  179. mindspore/parallel/_parallel_serialization.py +47 -19
  180. mindspore/parallel/_tensor.py +127 -13
  181. mindspore/parallel/_utils.py +53 -22
  182. mindspore/parallel/algo_parameter_config.py +5 -5
  183. mindspore/parallel/checkpoint_transform.py +46 -39
  184. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  185. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  186. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  187. mindspore/parallel/parameter_broadcast.py +3 -4
  188. mindspore/parallel/shard.py +162 -31
  189. mindspore/parallel/transform_safetensors.py +1146 -0
  190. mindspore/profiler/__init__.py +2 -1
  191. mindspore/profiler/common/constant.py +29 -0
  192. mindspore/profiler/common/registry.py +47 -0
  193. mindspore/profiler/common/util.py +28 -0
  194. mindspore/profiler/dynamic_profiler.py +694 -0
  195. mindspore/profiler/envprofiling.py +17 -19
  196. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  197. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  198. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  199. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  200. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  201. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  202. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  203. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  205. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  206. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  207. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  208. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  209. mindspore/profiler/parser/framework_parser.py +1 -391
  210. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  211. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  212. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  213. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  214. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  215. mindspore/profiler/parser/profiler_info.py +78 -6
  216. mindspore/profiler/profiler.py +153 -0
  217. mindspore/profiler/profiling.py +285 -413
  218. mindspore/rewrite/__init__.py +1 -2
  219. mindspore/rewrite/common/namespace.py +4 -4
  220. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  221. mindspore/run_check/_check_version.py +39 -104
  222. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  223. mindspore/train/__init__.py +4 -3
  224. mindspore/train/_utils.py +105 -19
  225. mindspore/train/amp.py +171 -53
  226. mindspore/train/callback/__init__.py +2 -2
  227. mindspore/train/callback/_callback.py +4 -4
  228. mindspore/train/callback/_checkpoint.py +97 -31
  229. mindspore/train/callback/_cluster_monitor.py +1 -1
  230. mindspore/train/callback/_flops_collector.py +1 -0
  231. mindspore/train/callback/_loss_monitor.py +3 -3
  232. mindspore/train/callback/_on_request_exit.py +145 -31
  233. mindspore/train/callback/_summary_collector.py +5 -5
  234. mindspore/train/callback/_tft_register.py +375 -0
  235. mindspore/train/dataset_helper.py +15 -3
  236. mindspore/train/metrics/metric.py +3 -3
  237. mindspore/train/metrics/roc.py +4 -4
  238. mindspore/train/mind_ir_pb2.py +44 -39
  239. mindspore/train/model.py +154 -58
  240. mindspore/train/serialization.py +342 -128
  241. mindspore/utils/__init__.py +21 -0
  242. mindspore/utils/utils.py +60 -0
  243. mindspore/version.py +1 -1
  244. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  245. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
  246. mindspore/include/c_api/ms/abstract.h +0 -67
  247. mindspore/include/c_api/ms/attribute.h +0 -197
  248. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  249. mindspore/include/c_api/ms/base/macros.h +0 -32
  250. mindspore/include/c_api/ms/base/status.h +0 -33
  251. mindspore/include/c_api/ms/base/types.h +0 -283
  252. mindspore/include/c_api/ms/context.h +0 -102
  253. mindspore/include/c_api/ms/graph.h +0 -160
  254. mindspore/include/c_api/ms/node.h +0 -606
  255. mindspore/include/c_api/ms/tensor.h +0 -161
  256. mindspore/include/c_api/ms/value.h +0 -84
  257. mindspore/mindspore_shared_lib.dll +0 -0
  258. mindspore/nn/extend/basic.py +0 -140
  259. mindspore/nn/extend/embedding.py +0 -143
  260. mindspore/nn/extend/layer/normalization.py +0 -109
  261. mindspore/nn/extend/pooling.py +0 -117
  262. mindspore/nn/layer/embedding_service.py +0 -531
  263. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  264. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  265. mindspore/ops/extend/__init__.py +0 -53
  266. mindspore/ops/extend/array_func.py +0 -218
  267. mindspore/ops/extend/math_func.py +0 -76
  268. mindspore/ops/extend/nn_func.py +0 -308
  269. mindspore/ops/silent_check.py +0 -162
  270. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  271. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  272. mindspore/train/callback/_mindio_ttp.py +0 -443
  273. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
  274. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  275. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,7 @@ from mindspore import log as logger
  from mindspore._c_expression import Shard_


- class Layout():
+ class Layout:
  """
  Parallel layout describes the detailed sharding information.

@@ -34,9 +34,7 @@ class Layout():
  device_matrix (tuple): Describe the shape of devices arrangement, its element type is int.
  alias_name (tuple): The alias name for each axis of device_matrix, its length should be equal to the length of device_matrix, its element type is string.
  When using "interleaved_parallel" as an alias name, the tensor would be split into multiple
- copies on the corresponding partition dimension on a single card. The corresponding value
- of "interleaved_parallel" in device_matrix must be 2.
-
+ copies on the corresponding partition dimension on a single card.
  Raises:
  TypeError: `device_matrix` is not a tuple type.
  TypeError: `alias_name` is not a tuple type.
@@ -52,7 +50,7 @@ class Layout():
  >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
  >>> layout0 = layout("dp", "mp")
  >>> print(layout0.to_dict())
- {"device_matrix": (2, 2, 2), "tensor_map": (2, 0)}
+ {"device_matrix": (2, 2, 2), "tensor_map": (2, 0), "interleaved_parallel": False}
  >>> # Total device num is 4, but split the tensor in local device into two copies.
  >>> layout = Layout((2, 2, 2), ("dp", "sp", "interleaved_parallel"))
  >>> layout1 = layout(("dp", "interleaved_parallel"), "sp")
@@ -81,9 +79,6 @@ class Layout():
  if inter_key in alias_name and alias_name.index(inter_key) != len(alias_name) - 1:
  raise ValueError(f"When alias_name {alias_name} contains keyword 'interleaved_parallel',"
  f" it should be at the last dim of alias_name, which means the virtual sharding.")
- if inter_key in alias_name and device_matrix[alias_name.index(inter_key)] != 2:
- raise ValueError(f"When alias_name {alias_name} contains keyword 'interleaved_parallel',"
- f" the corresponding dim of device_matrix should be 2.")
  self._device_shape = device_matrix
  self._alias_name = alias_name
  self._tensor_map = None
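
The check removed above previously required the "interleaved_parallel" axis of device_matrix to be exactly 2. A minimal sketch of what the relaxation appears to permit in 2.4.1, mirroring the docstring example earlier in this diff but with an interleave factor of 4 (not re-verified against a real install):

    from mindspore import Layout

    # 2 x 2 physical devices, plus an interleave factor of 4 on the last axis;
    # 2.3.0 rejected any value other than 2 for the "interleaved_parallel" dim.
    layout = Layout((2, 2, 4), ("dp", "sp", "interleaved_parallel"))
    layout1 = layout(("dp", "interleaved_parallel"), "sp")
    print(layout1.to_dict()["interleaved_parallel"])  # True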
@@ -127,7 +122,7 @@ class Layout():
  raise ValueError("The tensor_map of layout is None")
  interleaved_parallel = "interleaved_parallel" in self._alias_name
  return {"device_matrix": self._device_shape, "tensor_map": self._tensor_map,
- "interleaved_parallel": interleaved_parallel}
+ "interleaved_parallel": interleaved_parallel, "alias_name": self._alias_name}



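With the change above, Layout.to_dict() reports the alias names alongside the interleaved flag. A minimal sketch of the 2.4.1 result, with the expected values taken from this diff and from Layout's own docstring example (not re-verified here):

    from mindspore import Layout

    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout0 = layout("dp", "mp")

    info = layout0.to_dict()
    print(info["device_matrix"])          # (2, 2, 2)
    print(info["tensor_map"])             # (2, 0)
    print(info["interleaved_parallel"])   # False
    print(info["alias_name"])             # ("dp", "sp", "mp"), the new key added above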
@@ -147,22 +142,32 @@ class Shard(Shard_):

  def __call__(self, fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
  parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
- if parallel_mode not in ["auto_parallel", "semi_auto_parallel"]:
+ if parallel_mode not in ("auto_parallel", "semi_auto_parallel"):
  raise AssertionError(
  f"Cell shard only supports auto parallel and semi auto parallel.")
- if ms.context.get_context("device_target") not in ["Ascend", "GPU"]:
+ if ms.context.get_context("device_target") not in ("Ascend", "GPU"):
  raise AssertionError(
  f"'Shard' now only supports 'Ascend' and 'GPU'")
  if parallel_mode == "auto_parallel" and \
  ms.context.get_auto_parallel_context("search_mode") != "sharding_propagation":
  raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard' when the "
  f"'parallel_mode' is 'auto_parallel.'")
+
  if not isinstance(in_strategy, tuple):
  raise TypeError(
- f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}")
+ f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}.")
+ inner_type = self._check_layout_inner_type(in_strategy, "in_strategy")
+ if inner_type == "layout":
+ in_strategy = self._extract_layout_value(in_strategy, "in_strategy")
+
  if not isinstance(out_strategy, (type(None), tuple)):
  raise TypeError(f"For 'Shard', the 'out_strategy' should be None or tuple, "
- f"but got {type(out_strategy).__name__}")
+ f"but got {type(out_strategy).__name__}.")
+ if not isinstance(out_strategy, type(None)):
+ logger.warning("Out_strategy is not in use currently, will be ignored in the following procedures.")
+ inner_type = self._check_layout_inner_type(out_strategy, "out_strategy")
+ if inner_type == "layout":
+ out_strategy = self._extract_layout_value(out_strategy, "out_strategy")

  if not isinstance(device, str):
  raise TypeError(f"For 'Shard', the 'device' should be a string, "
@@ -238,9 +243,9 @@ class Shard(Shard_):
  f"If parameter_plan is set, type of fn must be mindspore.nn.Cell, but got {type(fn)}")
  for k in parameter_plan.keys():
  v = parameter_plan[k]
- if not isinstance(k, str) or not isinstance(v, tuple):
+ if not isinstance(k, str) or not isinstance(v, (tuple, Layout)):
  raise TypeError(f"For 'Shard', the type of each key and value in 'parameter_plan' must be str and "
- f"tuple, but got {type(k).__name__} and {type(v).__name__}")
+ f"tuple/Layout, but got {type(k).__name__} and {type(v).__name__}")
  else:
  raise TypeError(f"For 'Shard', the 'parameter_plan' should be a dict or None, "
  f"but got {type(parameter_plan).__name__}")
@@ -253,18 +258,68 @@ class Shard(Shard_):
  f"{param_name} is not exist, ignored its setting.")
  continue

- self._check_layout_is_valid(
- param_name, param.shape, param_strategy)
+ has_set = None
  if param.param_info.param_strategy:
- logger.warning(f"The layout of parameter '{param_name}' "
- f"has been set to {param.param_info.param_strategy}, "
- f"current setting {param_strategy} will be ignored.")
- param.param_info.param_strategy = param_strategy
+ has_set = "strategy"
+ if param.param_info.device_matrix:
+ has_set = "layout"
+ if has_set == "strategy":
+ logger.warning(f"The layout of parameter '{param_name}' has been set to "
+ f"{param.param_info.param_strategy}, current setting will be ignored.")
+ elif has_set == "layout":
+ logger.warning(f"The layout of parameter '{param_name}' has been set, "
+ f"current setting will be ignored.")
+ else:
+ if isinstance(param_strategy, tuple):
+ self._check_layout_is_valid(param_name, param.shape, param_strategy)
+ param.param_info.param_strategy = param_strategy
+ if isinstance(param_strategy, Layout):
+ param_layout = self._extract_layout_value((param_strategy,), "in_strategy")[0]
+ param.param_info.device_matrix = param_layout["device_matrix"]
+ param.param_info.tensor_map = param_layout["tensor_map"]
+ param.param_info.interleaved_parallel = param_layout["interleaved_parallel"]
+ param.param_info.alias_name = param_layout["alias_name"]

  def _is_attrs_has_been_set(self, fn, in_strategy, out_strategy, device, level):
  return self.shard_fn is not None and self.fn == fn and self.in_strategy == in_strategy and \
  self.out_strategy == out_strategy and self.device == device and self.level == level

+ def _check_layout_inner_type(self, strategy, log_info):
+ """Check inner item type of layout, should be int or ms.Layout."""
+ strategy_set = set()
+ for stra in strategy:
+ if not isinstance(stra, (tuple, Layout)):
+ raise TypeError(
+ f"The '{log_info}' should be a tuple(tuple(int)) or tuple(mindspore.Layout), "
+ f"but got {type(stra).__name__}")
+ if isinstance(stra, Layout):
+ strategy_set.add("layout")
+ elif isinstance(stra, tuple):
+ strategy_set.add("tuple")
+ self._check_tuple_strategy(stra)
+ if len(strategy_set) != 1:
+ raise TypeError(
+ f"For 'Shard', the strategy can only pass in consistent type for all dimensions.")
+ return strategy_set.pop()
+
+ def _extract_layout_value(self, layout, log_info):
+ """Extract parallel layout value"""
+ layout_value = None
+ if layout is not None:
+ if not isinstance(layout, tuple):
+ raise TypeError(f'{log_info} must be tuple type, but got:{type(layout)}')
+ layout_value = ()
+ for in_ele in layout:
+ if not isinstance(in_ele, Layout):
+ raise TypeError(f"The {log_info} item should be a object of class Layout.")
+ layout_value += (in_ele.to_dict(),)
+ return layout_value
+
+ def _check_tuple_strategy(self, dim_strategy):
+ if not all(isinstance(x, int) for x in dim_strategy):
+ raise TypeError(
+ f"The tuple strategy for each dimension should be tuple(int).")
+

  def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
  """
@@ -288,15 +343,16 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
  Its arguments and return value must be Tensor or Parameter.
  If `fn` is a Cell with parameters, `fn` needs to be an instantiated object,
  otherwise its arguments cannot be accessed.
- in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None.
- Tuple defines the layout of the corresponding input
- and None represents a data parallel strategy.
+ in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple(int) or
+ tuple(mindspore.Layout).
+ Tuple defines the layout of the corresponding input.
  out_strategy (Union[tuple, None]): Define the layout of outputs similar with `in_strategy`.
  It is not in use right now. Default: ``None`` .
  parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
  defines the layout of the parameter like "param_name: layout".
  The key is a parameter name of type 'str'.
- The value is a 1-D integer tuple, indicating the corresponding layout.
+ The value is a 1-D integer tuple or a 1-D mindspore.Layout tuple,
+ indicating the corresponding layout.
  If the parameter name is incorrect or the corresponding parameter
  has been set, the parameter setting will be ignored.
  Default: ``None`` .
@@ -314,9 +370,11 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
  AssertionError: If device_target it not "Ascend" or "GPU".
  TypeError: If `in_strategy` is not a tuple.
  TypeError: If `out_strategy` is not a tuple or None.
+ TypeError: If any element in `in_strategy` is not a tuple(int) or tuple(mindspore.Layout).
+ TypeError: If any element in `out_strategy` is not a tuple(int) or tuple(mindspore.Layout).
  TypeError: If `parameter_plan` is not a dict or None.
  TypeError: If any key in `parameter_plan` is not a str.
- TypeError: If any value in `parameter_plan` is not a tuple.
+ TypeError: If any value in `parameter_plan` is not a tuple(int) or a tuple(mindspore.Layout).
  TypeError: If `device` is not a str.
  TypeError: If `level` is not an integer.

@@ -326,23 +384,96 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
  Examples:
  >>> import numpy as np
  >>> import mindspore as ms
- >>> from mindspore import Tensor
+ >>> from mindspore import Tensor, nn
  >>> from mindspore.communication import init
- >>> ms.set_context(mode=ms.PYNATIVE_MODE)
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
  >>> init()
  >>> ms.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation",
- ...                              device_num=2)
+ ...                              device_num=8)
+ >>>
+ >>> # Case 1: cell uses functional
+ >>> class BasicBlock(nn.Cell):
+ >>>     def __init__(self):
+ >>>         super(BasicBlock, self).__init__()
+ >>>         self.dense1 = nn.Dense(64, 64)
+ >>>         self.gelu = nn.GELU()
+ >>>         def my_add(x, y):
+ >>>             x = ops.abs(x)
+ >>>             return x + y
+ >>>         # shard a function with tuple(int) strategies
+ >>>         self.shard_my_add = ms.shard(my_add, in_strategy=((2, 2), (1, 4)), out_strategy=((4, 1),))
+ >>>
+ >>>     def construct(self, x, u):
+ >>>         x = self.gelu(x)
+ >>>         y = self.gelu(u)
+ >>>         y = x * y
+ >>>         x = self.dense1(x)
+ >>>         x = self.shard_my_add(x, y)
+ >>>         return x
+ >>>
+ >>> class NetForward(nn.Cell):
+ >>>     def __init__(self):
+ >>>         super(NetForward, self).__init__()
+ >>>         self.block1 = BasicBlock()
+ >>>         self.block2 = BasicBlock()
+ >>>         self.matmul = ops.MatMul()
+ >>>
+ >>>     def construct(self, x, y):
+ >>>         x = self.matmul(x, y)
+ >>>         x = self.block1(x, x)
+ >>>         x = self.block2(x, x)
+ >>>         return x
+ >>>
+ >>> class Net(nn.Cell):
+ >>>     def __init__(self):
+ >>>         super(Net, self).__init__()
+ >>>         # setting cell sharding strategy and parameter_plan by tuple(int)
+ >>>         self.layer_net1 = NetForward()
+ >>>         self.layer_net1_shard = ms.shard(self.layer_net1, in_strategy=((4, 2), (2, 1)),
+ ...                                          parameter_plan={"self.layer_net1.block1.weight": (4, 1)})
+ >>>
+ >>>         # setting cell sharding strategy and parameter_plan by tuple(ms.Layout)
+ >>>         self.layer_net2 = NetForward()
+ >>>         layout = Layout((4, 2, 1), ("dp", "mp", "sp"))
+ >>>         in_layout = (layout("dp", "mp"), layout("mp", "sp"))
+ >>>         param_layout = layout("dp", "sp")
+ >>>         self.layer_net2_shard = ms.shard(self.layer_net2, in_strategy=in_layout,
+ ...                                          parameter_plan={"self.layer_net2.block2.weight": param_layout})
+ >>>         self.flatten = nn.Flatten()
+ >>>         self.layer1 = nn.Dense(64, 64)
+ >>>         self.layer2 = nn.Dense(64, 32)
+ >>>         self.add = ops.Add()
+ >>>         self.matmul = ops.MatMul()
+ >>>
+ >>>     def construct(self, x, y):
+ >>>         x = self.flatten(x)
+ >>>         y = self.flatten(y)
+ >>>         x = self.layer1(x)
+ >>>         x = self.layer_net1_shard(x, y)
+ >>>         x = self.layer_net2_shard(x, y)
+ >>>         x = self.layer2(x)
+ >>>         x = self.matmul(x, Tensor(np.ones(shape=(32, 32)), dtype=ms.float32))
+ >>>         return x
+ >>>
+ >>> net = Net()
+ >>> x = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
+ >>> y = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
+ >>> net(x, y)
+ >>>
+ >>> # Case 2: function uses functional sharding
  >>> def test_shard(x, y):
  ...     return x + y
  >>> x = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
  >>> y = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
- >>> output = ms.shard(test_shard, in_strategy=((2, 1), (2, 1)))(x, y)
+ >>> output = ms.shard(test_shard, in_strategy=((4, 2), (4, 2)))(x, y)
  >>> print(output.shape)
  (32, 10)

  Tutorial Examples:
  - `Functional Operator Sharding
- <https://www.mindspore.cn/tutorials/experts/en/master/parallel/shard_function_parallel.html>`_
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/shard_function_parallel.html>`_
+ - `mindspore.Layout
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.Layout.html>`_
  """
  if not isinstance(fn, (ms.nn.Cell)):
  logger.warning("'fn' is not a mindspore.nn.Cell, and its definition cannot involve Parameter; "