mindspore 2.3.0-cp310-cp310-win_amd64.whl → 2.4.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/parallel/shard.py
CHANGED
@@ -20,7 +20,7 @@ from mindspore import log as logger
 from mindspore._c_expression import Shard_
 
 
-class Layout():
+class Layout:
     """
     Parallel layout describes the detailed sharding information.
 
@@ -34,9 +34,7 @@ class Layout():
         device_matrix (tuple): Describe the shape of devices arrangement, its element type is int.
         alias_name (tuple): The alias name for each axis of device_matrix, its length should be equal to the length of device_matrix, its element type is string.
             When using "interleaved_parallel" as an alias name, the tensor would be split into multiple
-            copies on the corresponding partition dimension on a single card.
-            of "interleaved_parallel" in device_matrix must be 2.
-
+            copies on the corresponding partition dimension on a single card.
     Raises:
         TypeError: `device_matrix` is not a tuple type.
         TypeError: `alias_name` is not a tuple type.
@@ -52,7 +50,7 @@ class Layout():
         >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
         >>> layout0 = layout("dp", "mp")
         >>> print(layout0.to_dict())
-        {"device_matrix": (2, 2, 2), "tensor_map": (2, 0)}
+        {"device_matrix": (2, 2, 2), "tensor_map": (2, 0), "interleaved_parallel": False}
         >>> # Total device num is 4, but split the tensor in local device into two copies.
         >>> layout = Layout((2, 2, 2), ("dp", "sp", "interleaved_parallel"))
         >>> layout1 = layout(("dp", "interleaved_parallel"), "sp")
@@ -81,9 +79,6 @@ class Layout():
         if inter_key in alias_name and alias_name.index(inter_key) != len(alias_name) - 1:
             raise ValueError(f"When alias_name {alias_name} contains keyword 'interleaved_parallel',"
                              f" it should be at the last dim of alias_name, which means the virtual sharding.")
-        if inter_key in alias_name and device_matrix[alias_name.index(inter_key)] != 2:
-            raise ValueError(f"When alias_name {alias_name} contains keyword 'interleaved_parallel',"
-                             f" the corresponding dim of device_matrix should be 2.")
         self._device_shape = device_matrix
         self._alias_name = alias_name
         self._tensor_map = None
@@ -127,7 +122,7 @@ class Layout():
             raise ValueError("The tensor_map of layout is None")
         interleaved_parallel = "interleaved_parallel" in self._alias_name
         return {"device_matrix": self._device_shape, "tensor_map": self._tensor_map,
-                "interleaved_parallel": interleaved_parallel}
+                "interleaved_parallel": interleaved_parallel, "alias_name": self._alias_name}
 
 
 
@@ -147,22 +142,32 @@ class Shard(Shard_):
 
     def __call__(self, fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
         parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
-        if parallel_mode not in
+        if parallel_mode not in ("auto_parallel", "semi_auto_parallel"):
             raise AssertionError(
                 f"Cell shard only supports auto parallel and semi auto parallel.")
-        if ms.context.get_context("device_target") not in
+        if ms.context.get_context("device_target") not in ("Ascend", "GPU"):
             raise AssertionError(
                 f"'Shard' now only supports 'Ascend' and 'GPU'")
         if parallel_mode == "auto_parallel" and \
                 ms.context.get_auto_parallel_context("search_mode") != "sharding_propagation":
             raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard' when the "
                                  f"'parallel_mode' is 'auto_parallel.'")
+
         if not isinstance(in_strategy, tuple):
             raise TypeError(
-                f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}")
+                f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}.")
+        inner_type = self._check_layout_inner_type(in_strategy, "in_strategy")
+        if inner_type == "layout":
+            in_strategy = self._extract_layout_value(in_strategy, "in_strategy")
+
         if not isinstance(out_strategy, (type(None), tuple)):
             raise TypeError(f"For 'Shard', the 'out_strategy' should be None or tuple, "
-                            f"but got {type(out_strategy).__name__}")
+                            f"but got {type(out_strategy).__name__}.")
+        if not isinstance(out_strategy, type(None)):
+            logger.warning("Out_strategy is not in use currently, will be ignored in the following procedures.")
+            inner_type = self._check_layout_inner_type(out_strategy, "out_strategy")
+            if inner_type == "layout":
+                out_strategy = self._extract_layout_value(out_strategy, "out_strategy")
 
         if not isinstance(device, str):
             raise TypeError(f"For 'Shard', the 'device' should be a string, "
@@ -238,9 +243,9 @@ class Shard(Shard_):
                     f"If parameter_plan is set, type of fn must be mindspore.nn.Cell, but got {type(fn)}")
             for k in parameter_plan.keys():
                 v = parameter_plan[k]
-                if not isinstance(k, str) or not isinstance(v, tuple):
+                if not isinstance(k, str) or not isinstance(v, (tuple, Layout)):
                     raise TypeError(f"For 'Shard', the type of each key and value in 'parameter_plan' must be str and "
-                                    f"tuple, but got {type(k).__name__} and {type(v).__name__}")
+                                    f"tuple/Layout, but got {type(k).__name__} and {type(v).__name__}")
         else:
             raise TypeError(f"For 'Shard', the 'parameter_plan' should be a dict or None, "
                             f"but got {type(parameter_plan).__name__}")
@@ -253,18 +258,68 @@ class Shard(Shard_):
                                f"{param_name} is not exist, ignored its setting.")
                 continue
 
-
-                param_name, param.shape, param_strategy)
+            has_set = None
             if param.param_info.param_strategy:
-
-
-
-
+                has_set = "strategy"
+            if param.param_info.device_matrix:
+                has_set = "layout"
+            if has_set == "strategy":
+                logger.warning(f"The layout of parameter '{param_name}' has been set to "
+                               f"{param.param_info.param_strategy}, current setting will be ignored.")
+            elif has_set == "layout":
+                logger.warning(f"The layout of parameter '{param_name}' has been set, "
+                               f"current setting will be ignored.")
+            else:
+                if isinstance(param_strategy, tuple):
+                    self._check_layout_is_valid(param_name, param.shape, param_strategy)
+                    param.param_info.param_strategy = param_strategy
+                if isinstance(param_strategy, Layout):
+                    param_layout = self._extract_layout_value((param_strategy,), "in_strategy")[0]
+                    param.param_info.device_matrix = param_layout["device_matrix"]
+                    param.param_info.tensor_map = param_layout["tensor_map"]
+                    param.param_info.interleaved_parallel = param_layout["interleaved_parallel"]
+                    param.param_info.alias_name = param_layout["alias_name"]
 
     def _is_attrs_has_been_set(self, fn, in_strategy, out_strategy, device, level):
         return self.shard_fn is not None and self.fn == fn and self.in_strategy == in_strategy and \
             self.out_strategy == out_strategy and self.device == device and self.level == level
 
+    def _check_layout_inner_type(self, strategy, log_info):
+        """Check inner item type of layout, should be int or ms.Layout."""
+        strategy_set = set()
+        for stra in strategy:
+            if not isinstance(stra, (tuple, Layout)):
+                raise TypeError(
+                    f"The '{log_info}' should be a tuple(tuple(int)) or tuple(mindspore.Layout), "
+                    f"but got {type(stra).__name__}")
+            if isinstance(stra, Layout):
+                strategy_set.add("layout")
+            elif isinstance(stra, tuple):
+                strategy_set.add("tuple")
+                self._check_tuple_strategy(stra)
+        if len(strategy_set) != 1:
+            raise TypeError(
+                f"For 'Shard', the strategy can only pass in consistent type for all dimensions.")
+        return strategy_set.pop()
+
+    def _extract_layout_value(self, layout, log_info):
+        """Extract parallel layout value"""
+        layout_value = None
+        if layout is not None:
+            if not isinstance(layout, tuple):
+                raise TypeError(f'{log_info} must be tuple type, but got:{type(layout)}')
+            layout_value = ()
+            for in_ele in layout:
+                if not isinstance(in_ele, Layout):
+                    raise TypeError(f"The {log_info} item should be a object of class Layout.")
+                layout_value += (in_ele.to_dict(),)
+        return layout_value
+
+    def _check_tuple_strategy(self, dim_strategy):
+        if not all(isinstance(x, int) for x in dim_strategy):
+            raise TypeError(
+                f"The tuple strategy for each dimension should be tuple(int).")
+
 
 def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
     """
@@ -288,15 +343,16 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
             Its arguments and return value must be Tensor or Parameter.
             If `fn` is a Cell with parameters, `fn` needs to be an instantiated object,
             otherwise its arguments cannot be accessed.
-        in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or
-
-
+        in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple(int) or
+            tuple(mindspore.Layout).
+            Tuple defines the layout of the corresponding input.
         out_strategy (Union[tuple, None]): Define the layout of outputs similar with `in_strategy`.
             It is not in use right now. Default: ``None`` .
         parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
             defines the layout of the parameter like "param_name: layout".
            The key is a parameter name of type 'str'.
-            The value is a 1-D integer tuple
+            The value is a 1-D integer tuple or a 1-D mindspore.Layout tuple,
+            indicating the corresponding layout.
            If the parameter name is incorrect or the corresponding parameter
            has been set, the parameter setting will be ignored.
            Default: ``None`` .
@@ -314,9 +370,11 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
         AssertionError: If device_target it not "Ascend" or "GPU".
         TypeError: If `in_strategy` is not a tuple.
         TypeError: If `out_strategy` is not a tuple or None.
+        TypeError: If any element in `in_strategy` is not a tuple(int) or tuple(mindspore.Layout).
+        TypeError: If any element in `out_strategy` is not a tuple(int) or tuple(mindspore.Layout).
         TypeError: If `parameter_plan` is not a dict or None.
         TypeError: If any key in `parameter_plan` is not a str.
-        TypeError: If any value in `parameter_plan` is not a tuple.
+        TypeError: If any value in `parameter_plan` is not a tuple(int) or a tuple(mindspore.Layout).
         TypeError: If `device` is not a str.
         TypeError: If `level` is not an integer.
 
@@ -326,23 +384,96 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
     Examples:
         >>> import numpy as np
         >>> import mindspore as ms
-        >>> from mindspore import Tensor
+        >>> from mindspore import Tensor, nn
        >>> from mindspore.communication import init
-        >>> ms.set_context(mode=ms.
+        >>> ms.set_context(mode=ms.GRAPH_MODE)
        >>> init()
        >>> ms.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation",
-        ...
+        ...                              device_num=8)
+        >>>
+        >>> # Case 1: cell uses functional
+        >>> class BasicBlock(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(BasicBlock, self).__init__()
+        >>>         self.dense1 = nn.Dense(64, 64)
+        >>>         self.gelu = nn.GELU()
+        >>>         def my_add(x, y):
+        >>>             x = ops.abs(x)
+        >>>             return x + y
+        >>>         # shard a function with tuple(int) strategies
+        >>>         self.shard_my_add = ms.shard(my_add, in_strategy=((2, 2), (1, 4)), out_strategy=((4, 1),))
+        >>>
+        >>>     def construct(self, x, u):
+        >>>         x = self.gelu(x)
+        >>>         y = self.gelu(u)
+        >>>         y = x * y
+        >>>         x = self.dense1(x)
+        >>>         x = self.shard_my_add(x, y)
+        >>>         return x
+        >>>
+        >>> class NetForward(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(NetForward, self).__init__()
+        >>>         self.block1 = BasicBlock()
+        >>>         self.block2 = BasicBlock()
+        >>>         self.matmul = ops.MatMul()
+        >>>
+        >>>     def construct(self, x, y):
+        >>>         x = self.matmul(x, y)
+        >>>         x = self.block1(x, x)
+        >>>         x = self.block2(x, x)
+        >>>         return x
+        >>>
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         # setting cell sharding strategy and parameter_plan by tuple(int)
+        >>>         self.layer_net1 = NetForward()
+        >>>         self.layer_net1_shard = ms.shard(self.layer_net1, in_strategy=((4, 2), (2, 1)),
+        ...                                          parameter_plan={"self.layer_net1.block1.weight": (4, 1)})
+        >>>
+        >>>         # setting cell sharding strategy and parameter_plan by tuple(ms.Layout)
+        >>>         self.layer_net2 = NetForward()
+        >>>         layout = Layout((4, 2, 1), ("dp", "mp", "sp"))
+        >>>         in_layout = (layout("dp", "mp"), layout("mp", "sp"))
+        >>>         param_layout = layout("dp", "sp")
+        >>>         self.layer_net2_shard = ms.shard(self.layer_net2, in_strategy=in_layout,
+        ...                                          parameter_plan={"self.layer_net2.block2.weight": param_layout})
+        >>>         self.flatten = nn.Flatten()
+        >>>         self.layer1 = nn.Dense(64, 64)
+        >>>         self.layer2 = nn.Dense(64, 32)
+        >>>         self.add = ops.Add()
+        >>>         self.matmul = ops.MatMul()
+        >>>
+        >>>     def construct(self, x, y):
+        >>>         x = self.flatten(x)
+        >>>         y = self.flatten(y)
+        >>>         x = self.layer1(x)
+        >>>         x = self.layer_net1_shard(x, y)
+        >>>         x = self.layer_net2_shard(x, y)
+        >>>         x = self.layer2(x)
+        >>>         x = self.matmul(x, Tensor(np.ones(shape=(32, 32)), dtype=ms.float32))
+        >>>         return x
+        >>>
+        >>> net = Net()
+        >>> x = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
+        >>> y = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
+        >>> net(x, y)
+        >>>
+        >>> # Case 2: function uses functional sharding
         >>> def test_shard(x, y):
         ...     return x + y
         >>> x = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
         >>> y = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
-        >>> output = ms.shard(test_shard, in_strategy=((
+        >>> output = ms.shard(test_shard, in_strategy=((4, 2), (4, 2)))(x, y)
         >>> print(output.shape)
         (32, 10)
 
     Tutorial Examples:
         - `Functional Operator Sharding
-          <https://www.mindspore.cn/
+          <https://www.mindspore.cn/docs/en/master/model_train/parallel/shard_function_parallel.html>`_
+        - `mindspore.Layout
+          <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.Layout.html>`_
     """
     if not isinstance(fn, (ms.nn.Cell)):
         logger.warning("'fn' is not a mindspore.nn.Cell, and its definition cannot involve Parameter; "
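
A minimal usage sketch of the behavior introduced by the shard.py hunks above, for orientation only: in 2.4.0, ms.shard accepts mindspore.Layout objects (not just tuple(int)) in in_strategy and parameter_plan, and Layout.to_dict() additionally reports "alias_name". The snippet is not part of the package diff; it assumes an 8-device Ascend or GPU job launched (e.g. via msrun) so that init() succeeds, and the cell and parameter names are illustrative.

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn, ops, Layout
from mindspore.communication import init

ms.set_context(mode=ms.GRAPH_MODE)
init()  # requires a multi-device launch; a plain single-process run will fail here
ms.set_auto_parallel_context(parallel_mode="auto_parallel",
                             search_mode="sharding_propagation", device_num=8)

# Layout.to_dict() now also carries the axis alias names.
layout = Layout((4, 2, 1), ("dp", "mp", "sp"))
print(layout("dp", "mp").to_dict())  # includes "interleaved_parallel" and "alias_name"

class MatMulDense(nn.Cell):
    """Illustrative cell (not taken from the diff): matmul followed by a dense layer."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(64, 64)
        self.matmul = ops.MatMul()

    def construct(self, x, y):
        return self.dense(self.matmul(x, y))

net = MatMulDense()
# in_strategy and parameter_plan values may now be Layout objects instead of tuple(int).
shard_net = ms.shard(net, in_strategy=(layout("dp", "mp"), layout("mp", "sp")),
                     parameter_plan={"dense.weight": layout("dp", "sp")})

x = Tensor(np.ones((64, 64)), dtype=ms.float32)
y = Tensor(np.ones((64, 64)), dtype=ms.float32)
out = shard_net(x, y)

Tuple(int) strategies keep working exactly as in 2.3.0; per _check_layout_inner_type in the diff, mixing tuple(int) and Layout entries within one strategy raises TypeError.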