mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/common/parameter.py
CHANGED

@@ -22,6 +22,7 @@ import os
 import sys
 import math
 import numbers
+from contextlib import contextmanager
 import numpy as np
 from mindspore import log as logger
 from mindspore.log import _LogActionOnce
@@ -41,6 +42,8 @@ from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _i
     _is_ps_mode
 from mindspore.parallel._ps_context import _reinsert_hash_table_size, _insert_accumu_init_info, _cache_enable
 from mindspore.common._decorator import deprecated
+from mindspore.communication._comm_helper import _is_initialized
+from mindspore.communication import get_group_size
 import mindspore.common._monad as monad

 __all__ = ['Parameter', 'ParameterTuple']
@@ -52,11 +55,32 @@ PARAMETER_NAME_PREFIX_MAX_LEN = 1024
 _GLOBAL_PARAMETER_KEY = -1


-
+@contextmanager
+def no_init_parameters():
+    init_class = globals()["Parameter"]
+    setattr(init_class, "init_param", False)
+    try:
+        yield
+    finally:
+        setattr(init_class, "init_param", True)
+
+
 def _is_in_auto_parallel_mode():
     """Get parallel mode."""
     return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]


+def _is_parallel_mode():
+    """Whether is parallel mode."""
+    if not _is_initialized() or context.get_context('mode') == context.PYNATIVE_MODE:
+        return False
+    if os.getenv("RUN_MODE") != "predict":
+        return False
+    if get_group_size() > 1 and _get_parallel_mode() == "stand_alone":
+        return True
+    return False
+
+
 def init_to_value(init):
     """
     Get value of initializer.
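The new `no_init_parameters` context manager flips a class-level `init_param` flag on `Parameter`, and `_get_parameter_new_args` (changed further down) skips eager data initialization while the flag is cleared. A minimal usage sketch, assuming the deferred weights are later overwritten anyway, e.g. from a checkpoint; `dense.ckpt` is a hypothetical path:

```python
import mindspore as ms
from mindspore import nn
from mindspore.common.parameter import no_init_parameters

# Parameters created inside the block keep their initializer metadata but do
# not materialize their data eagerly (init_param is False while it is open).
with no_init_parameters():
    net = nn.Dense(4, 2)

# Fill the deferred weights from a checkpoint instead of running initializers.
ms.load_checkpoint("dense.ckpt", net)
```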
@@ -91,6 +115,15 @@ def _get_unique_parameter_key():
     return _GLOBAL_PARAMETER_KEY


+def _gen_offload_file_path(offload_dir):
+    offload_dir = os.path.relpath(offload_dir)
+    if not os.path.exists(offload_dir):
+        os.makedirs(offload_dir, mode=0o700, exist_ok=True)
+    offload_file_path = offload_dir + "/" + str(_get_global_rank()) + "_" + str(
+        _get_unique_parameter_key()) + "_" + str(time.time()) + ".data"
+    return offload_file_path
+
+
 def _offload_if_config(data):
     """
     Offload parameter(data size > 512) to file when enable memory offload and offload parameter to disk.
@@ -111,11 +144,7 @@ def _offload_if_config(data):
     offload_file_path = data.offload_file_path()
     if offload_file_path is None or offload_file_path == "":
         offload_dir = offload_context.get("offload_path", "./offload")
-
-        if not os.path.exists(offload_dir):
-            os.makedirs(offload_dir)
-        offload_file_path = offload_dir + "/" + str(_get_global_rank()) + "_" + str(
-            _get_unique_parameter_key()) + "_" + str(time.time()) + ".data"
+        offload_file_path = _gen_offload_file_path(offload_dir)
     data.offload(offload_file_path)


@@ -191,6 +220,12 @@ class Parameter(Tensor_):
         storage_format (str): Only Ascend device target is supported. It is used to specify the format of the weight
             loaded to the device. By default, the format is not changed. The optional values are ``"FRACTAL_NZ"`` ,
             ``"NC1HWC0"`` , ``"FRACTAL_Z"`` , etc. Default: ``""`` .
+        device (str): Only Ascend device target is supported. It is used to specify the device on which the
+            parameter is stored. By default, the parameter is stored on the NPU while computing. When the device is
+            specified as ``"CPU"``, the parameter is loaded onto the device when it needs to be used, and unloaded
+            back to the CPU after use. It takes effect only when `memory_offload` is ``"ON"``, `jit_level` is not
+            ``"O2"`` and `memory_optimize_level` is ``"O0"`` in `mindspore.set_context()`. Less device memory is
+            needed when device is specified as ``"CPU"``.

     Examples:
         >>> import numpy as np
@@ -219,7 +254,8 @@ class Parameter(Tensor_):
     def __new__(cls, default_input, *args, **kwargs):
         init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init)
         rc = sys.getrefcount(default_input)
-        input_class, *class_init_args = Parameter._get_parameter_new_args(default_input, rc)
+        init_param = getattr(cls, "init_param", True)
+        input_class, *class_init_args = Parameter._get_parameter_new_args(default_input, rc, init_param)
         new_type = Parameter._get_base_class(input_class)
         obj = input_class.__new__(new_type)
         input_class.__init__(obj, *class_init_args)
@@ -244,7 +280,7 @@ class Parameter(Tensor_):
             Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))

     def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True,
-                 storage_format=""):
+                 storage_format="", device=None):
         self.param_info = ParamInfo()
         self.init_in_server = False
         self.name = name
@@ -263,7 +299,7 @@ class Parameter(Tensor_):
         self.requires_aggr = True
         self._cast_type = None
         self._unique = False
-        self.is_in_parallel =
+        self.is_in_parallel = _is_in_auto_parallel_mode()
         self.is_in_shard = False
         self._pipeline_stage_list = []
         self.slice_num = 1
@@ -296,6 +332,10 @@ class Parameter(Tensor_):
                             f" 'numpy.ndarray', 'list']. But got type {type(default_input)}.")
         self.param_info.parameter_shape = self.shape
         self.param_info.storage_format = storage_format
+        if device is not None:
+            if device != "CPU":
+                raise ValueError(f"Only 'CPU' is supported for device, but got {device}.")
+            self._set_user_data("parameter_device", device)

         import mindspore.ops.operations.other_ops as other_ops
         self.load = other_ops.Load()
@@ -327,7 +367,7 @@ class Parameter(Tensor_):
         return new_type

     @staticmethod
-    def _get_parameter_new_args(data, rc):
+    def _get_parameter_new_args(data, rc, init_param=True):
         """Set `set_data` of current `Parameter`."""
         if isinstance(data, bool):
             raise ValueError('Parameter data can not be `bool`')
@@ -342,7 +382,8 @@ class Parameter(Tensor_):
                 return (Tensor, data.asnumpy(), mstype.qint4x2)
             return (Tensor, data.asnumpy())

-        not_init_data = _is_role_sched() or (_is_role_pserver() and _cache_enable())
+        not_init_data = not init_param or _is_role_sched() or (_is_role_pserver() and _cache_enable()) \
+                        or _is_in_auto_parallel_mode() or _is_parallel_mode()
        if not_init_data:
            # do not init data while in auto parallel.
            return (Tensor, None, data.dtype, get_slice_shape(data.dtype, data.shape), data.init)
@@ -368,7 +409,7 @@ class Parameter(Tensor_):

         Tutorial Examples:
             - `Parameter Server Mode
-              <https://www.mindspore.cn/
+              <https://www.mindspore.cn/docs/en/master/model_train/parallel/parameter_server_training.html>`_
         """
         if not _is_ps_mode() or not (_is_role_worker() or _is_role_pserver() or _is_role_sched()):
             raise RuntimeError("Must complete following two steps before calling set_param_ps: \n"
@@ -616,6 +657,9 @@ class Parameter(Tensor_):
         shape = self.shape if self.slice_num == 1 else self.param_info.origin_shape
         dtype = self.dtype
         x.set_data(initializer(init, shape=shape, dtype=dtype))
+        device = self._get_user_data("parameter_device")
+        if device is not None:
+            x._set_user_data("parameter_device", device)
         return x

     @property
@@ -942,8 +986,10 @@ class Parameter(Tensor_):
         >>> x = Parameter(Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32)), name="param")
         >>> x.init_data()
         """
-        if self.is_default_input_init and self.is_in_parallel !=
+        if self.is_default_input_init and self.is_in_parallel != _is_in_auto_parallel_mode():
             raise RuntimeError("Must set or change parallel mode before any initializer Tensor created.")
+        if hasattr(self, "init_param") and self.init_param:
+            return self
         if self.init_mode is None:
             return self
         if self.inited_param is not None:
@@ -1026,8 +1072,9 @@ class ParameterTuple(tuple):
             Tuple, the new Parameter tuple.

         Tutorial Examples:
-            - `
-              <https://mindspore.cn/
+            - `Tensor and Parameter - Parameter Tuple
+              <https://mindspore.cn/docs/en/master/model_train/model_building/tensor_and_parameter.html
+              #parameter-tuple>`_
         """
         Validator.check_str_by_regular(prefix)
         new = []
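Read together, the hunks above add an end-to-end path for host-resident parameters: `__init__` accepts `device="CPU"` and stores it as user data, `clone` propagates it, and `_gen_offload_file_path` centralizes offload file naming. A hedged sketch of the caller side, assuming an Ascend target and the context settings the new docstring requires; the exact `set_context` combination is my reading of that docstring, not something shown in this diff:

```python
import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor

# Per the new docstring: device="CPU" only takes effect when memory_offload
# is "ON", jit_level is not "O2", and memory_optimize_level is "O0".
ms.set_context(device_target="Ascend", memory_offload="ON",
               memory_optimize_level="O0", jit_config={"jit_level": "O0"})

# This weight stays in host memory and is copied to the NPU only while an
# operator actually consumes it, then released from device memory again.
w = Parameter(Tensor(np.zeros((4096, 4096), np.float32)), name="w", device="CPU")
```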
mindspore/common/recompute.py
CHANGED

@@ -23,8 +23,10 @@ from mindspore.common.tensor import Tensor
 from mindspore import ops
 from mindspore.ops.composite import GradOperation
 from mindspore.common._register_for_recompute import recompute_registry
-from mindspore.common.api import _pynative_executor
+from mindspore.common.api import _pynative_executor, _no_grad
 from mindspore.common.generator import get_rng_state, set_rng_state
+from mindspore.train.amp import amp_decorator
+from mindspore._c_expression.amp import get_curr_amp_strategy


 class _WrapCell(Cell):
@@ -34,7 +36,7 @@ class _WrapCell(Cell):
     """

     def __init__(self, function):
-        super(_WrapCell, self).__init__()
+        super(_WrapCell, self).__init__(auto_prefix=False)
         self.function = function

     def construct(self, *args, **kwargs):
@@ -56,6 +58,7 @@ class _RecomputeCell(Cell):
         self.args = []
         self.kwargs = []
         self.wrap_cell = _WrapCell(block)
+        self.wrap_cell.set_inputs()

         self.net = block
         self.internal_params = []
@@ -64,15 +67,18 @@ class _RecomputeCell(Cell):
         self._add_attr("is_cell_recompute", "True")
         self.grad = GradOperation(get_all=True, get_by_list=True, sens_param=True)
         self.init_mixed_precision_type(block)
+        self.amp_strategy = None

     def construct(self, *args, **kwargs):
-        _check_input_args_validate(self.net, args)
+        _check_input_args_validate(self.net, args, kwargs)
         self.args.append(args)
         self.kwargs.append(kwargs)
         self.save_rng_state = kwargs.pop("save_rng_state", True)
         if self.save_rng_state:
             self.cpu_rng_state = get_rng_state()
-        return self.net(*args, **kwargs)
+        self.amp_strategy = get_curr_amp_strategy()
+        with _no_grad():
+            return self.net(*args, **kwargs)

     def bprop(self, *args):
         """
@@ -86,14 +92,23 @@ class _RecomputeCell(Cell):
         self.args.pop()
         self.kwargs.pop()
         if kwargs:
-            kwargs['sens'] = grad_input
+            input_args_for_check = list(input_args) + list(kwargs.values())
+        else:
+            input_args_for_check = list(input_args)
         # To detach inputs to avoid erasing auto grad meta info of origin inputs.
         input_args = _detach_input(input_args)
+        kwargs = _detach_input(kwargs)
+        kwargs['sens'] = grad_input
         try:
             pre_rng_state = get_rng_state()
             set_rng_state(self.cpu_rng_state)
             _pynative_executor.set_is_run_recompute(True)
-            grads = self.grad(self.net, self.internal_params)(*input_args, **kwargs)
+            if self.amp_strategy:
+                with amp_decorator(self.amp_strategy.get_amp_level(), self.amp_strategy.get_amp_dtype(),
+                                   self.amp_strategy.get_white_list(), self.amp_strategy.get_black_list()):
+                    grads = self.grad(self.net, self.internal_params)(*input_args, **kwargs)
+            else:
+                grads = self.grad(self.net, self.internal_params)(*input_args, **kwargs)
             _pynative_executor.set_is_run_recompute(False)
             set_rng_state(pre_rng_state)
         except Exception as err:
@@ -101,7 +116,7 @@ class _RecomputeCell(Cell):
             raise err
         weights = OrderedDict()
         input_grads = list(grads[0])
-        _padding_input_grads(input_args, input_grads)
+        _padding_input_grads(input_args_for_check, input_grads)
         for i, param in enumerate(self.internal_params):
             weights[param] = grads[1][i]
         return tuple(input_grads), weights
@@ -121,6 +136,7 @@ class _RecomputeCell(Cell):
         # To avoid sub cell same name
         block.__self__.check_names_and_refresh_name()
         self.internal_params = block.__self__.trainable_params()
+        self.wrap_cell.mixed_precision_type = block.__self__.get_mixed_precision_type()
         self.wrap_cell.set_mixed_precision_type(block.__self__.get_mixed_precision_type())
         self.net = self.wrap_cell
     else:
@@ -128,13 +144,14 @@ class _RecomputeCell(Cell):
                         "only support Cell object or MethodType function!")


-def _check_input_args_validate(block, args):
+def _check_input_args_validate(block, args, kwargs):
     """
     Check recompute input args validate
     :param args:
     :return:
     """
-    if not any([isinstance(arg, Tensor) for arg in args]):
+    if not (any([isinstance(arg, Tensor) for arg in args]) or \
+            any([isinstance(arg, Tensor) for arg in kwargs.values()])):
         logger.warning("None of the inputs of function are tensors, which not need use recompute!")
     for arg in args:
         if isinstance(arg, (tuple, list)):
@@ -168,6 +185,11 @@ def _padding_input_grads(args, input_grads):


 def _detach_input(input_arg):
+    """
+    Detach input
+    :param input_arg:
+    :return: detach output
+    """
     if isinstance(input_arg, Tensor):
         return ops.stop_gradient(input_arg)
     if isinstance(input_arg, (list, tuple)):
@@ -175,6 +197,14 @@ def _detach_input(input_arg):
         for arg in input_arg:
             detach_inputs.append(_detach_input(arg))
         return detach_inputs if isinstance(input_arg, list) else tuple(detach_inputs)
+    if isinstance(input_arg, dict):
+        detach_inputs = {}
+        for key, val in input_arg.items():
+            if isinstance(val, Tensor):
+                detach_inputs[key] = ops.stop_gradient(val)
+            else:
+                detach_inputs[key] = val
+        return detach_inputs
     return input_arg
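The recompute changes above mainly extend keyword-argument support: tensors passed as kwargs are now validated, detached (via the new dict branch of `_detach_input`), and padded into the input gradients, and the forward pass runs under `_no_grad` with the captured AMP strategy replayed in `bprop`. A hedged PyNative-mode sketch of the public `mindspore.recompute` API exercising a keyword tensor; the network itself is illustrative, not from this diff:

```python
import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)

class Block(nn.Cell):
    def construct(self, x, bias=None):
        y = ops.relu(x)
        if bias is not None:
            y = y + bias  # a tensor passed by keyword
        return y

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def construct(self, x, bias):
        # Forward activations of `block` are discarded and rebuilt during
        # bprop; the keyword tensor is detached and padded like positionals.
        return ms.recompute(self.block, x, bias=bias)

net = Net()
x = Tensor(np.ones((2, 3), np.float32))
bias = Tensor(np.ones((2, 3), np.float32))
dx, dbias = ms.grad(net, grad_position=(0, 1))(x, bias)
```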
mindspore/common/sparse_tensor.py
CHANGED

@@ -97,7 +97,8 @@ class RowTensor(RowTensorInner):
          [0, 0]]

     .. warning::
-        This is an experimental API that is subjected to change or deletion.
+        - This is an experimental API that is subjected to change or deletion.
+        - If use PyNative mode, set "export MS_PYNATIVE_CONFIG_STATIC_SHAPE=1".

     Args:
         indices (Tensor): A 1-D integer Tensor of shape :math:`(d_0)` . Default: ``None``.
@@ -226,10 +227,11 @@ class COOTensor(COOTensor_):

     Common arithmetic operations include: addition (+), subtraction (-), multiplication (*),
     and division (/). For details about operations supported by `COOTensor`, see
-    `operators <https://www.mindspore.cn/docs/en/master/
+    `operators <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#operators>`_.

     .. warning::
         - This is an experimental API that is subject to change or deletion.
+        - If use PyNative mode, set "export MS_PYNATIVE_CONFIG_STATIC_SHAPE=1".
         - Currently, duplicate coordinates in the indices will not be coalesced.
           If the indices contain out-of-bound values, the result will be undefined.
@@ -646,6 +648,7 @@ class CSRTensor(CSRTensor_):
         [1., 2., 3., 4., 5., 6.], shape is (3, 5), then the dense representation of the sparse tensor will be:

         .. code-block::
+
             [[1., 0., 0., 2., 0.],
              [0., 3., 4., 0., 5.],
              [0., 0., 6., 0., 0.]]
@@ -668,10 +671,11 @@ class CSRTensor(CSRTensor_):

     Common arithmetic operations include: addition (+), subtraction (-), multiplication (*),
     and division (/). For details about operations supported by `CSRTensor`, see
-    `operators <https://www.mindspore.cn/docs/en/master/
+    `operators <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#operators>`_.

     .. warning::
         - This is an experimental API that is subjected to change.
+        - If use PyNative mode, set "export MS_PYNATIVE_CONFIG_STATIC_SHAPE=1".
         - If the values given by `indptr` or `indices` are invalid, the results may be undefined. Invalid values include
           when the length of `values` or `indices` exceeds the range indicated by `indptr`, and when the columns
           indicated by `indices` are repeated on the same row.