mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (287) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/initializer.py +51 -15
  26. mindspore/common/mindir_util.py +2 -2
  27. mindspore/common/parameter.py +62 -15
  28. mindspore/common/recompute.py +39 -9
  29. mindspore/common/sparse_tensor.py +7 -3
  30. mindspore/common/tensor.py +183 -37
  31. mindspore/communication/__init__.py +1 -1
  32. mindspore/communication/_comm_helper.py +38 -3
  33. mindspore/communication/comm_func.py +315 -60
  34. mindspore/communication/management.py +14 -14
  35. mindspore/context.py +132 -22
  36. mindspore/dataset/__init__.py +1 -1
  37. mindspore/dataset/audio/__init__.py +1 -1
  38. mindspore/dataset/core/config.py +7 -0
  39. mindspore/dataset/core/validator_helpers.py +7 -0
  40. mindspore/dataset/engine/cache_client.py +1 -1
  41. mindspore/dataset/engine/datasets.py +72 -44
  42. mindspore/dataset/engine/datasets_audio.py +7 -7
  43. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  44. mindspore/dataset/engine/datasets_text.py +20 -20
  45. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  46. mindspore/dataset/engine/datasets_vision.py +33 -33
  47. mindspore/dataset/engine/iterators.py +29 -0
  48. mindspore/dataset/engine/obs/util.py +7 -0
  49. mindspore/dataset/engine/queue.py +114 -60
  50. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  51. mindspore/dataset/engine/validators.py +34 -14
  52. mindspore/dataset/text/__init__.py +1 -4
  53. mindspore/dataset/transforms/__init__.py +0 -3
  54. mindspore/dataset/utils/line_reader.py +2 -0
  55. mindspore/dataset/vision/__init__.py +1 -4
  56. mindspore/dataset/vision/utils.py +1 -1
  57. mindspore/dataset/vision/validators.py +2 -1
  58. mindspore/dnnl.dll +0 -0
  59. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  60. mindspore/experimental/es/embedding_service.py +883 -0
  61. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  62. mindspore/experimental/llm_boost/__init__.py +21 -0
  63. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  64. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  65. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  66. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  67. mindspore/experimental/llm_boost/register.py +129 -0
  68. mindspore/experimental/llm_boost/utils.py +31 -0
  69. mindspore/experimental/optim/adamw.py +85 -0
  70. mindspore/experimental/optim/optimizer.py +3 -0
  71. mindspore/hal/__init__.py +3 -3
  72. mindspore/hal/contiguous_tensors_handle.py +175 -0
  73. mindspore/hal/stream.py +18 -0
  74. mindspore/include/api/model_group.h +13 -1
  75. mindspore/include/api/types.h +10 -10
  76. mindspore/include/dataset/config.h +2 -2
  77. mindspore/include/dataset/constants.h +2 -2
  78. mindspore/include/dataset/execute.h +2 -2
  79. mindspore/include/dataset/vision.h +4 -0
  80. mindspore/jpeg62.dll +0 -0
  81. mindspore/log.py +1 -1
  82. mindspore/mindrecord/filewriter.py +68 -51
  83. mindspore/mindspore_backend.dll +0 -0
  84. mindspore/mindspore_common.dll +0 -0
  85. mindspore/mindspore_core.dll +0 -0
  86. mindspore/mindspore_glog.dll +0 -0
  87. mindspore/mindspore_np_dtype.dll +0 -0
  88. mindspore/mindspore_ops.dll +0 -0
  89. mindspore/mint/__init__.py +983 -46
  90. mindspore/mint/distributed/__init__.py +31 -0
  91. mindspore/mint/distributed/distributed.py +254 -0
  92. mindspore/mint/nn/__init__.py +268 -23
  93. mindspore/mint/nn/functional.py +125 -19
  94. mindspore/mint/nn/layer/__init__.py +39 -0
  95. mindspore/mint/nn/layer/activation.py +133 -0
  96. mindspore/mint/nn/layer/normalization.py +477 -0
  97. mindspore/mint/nn/layer/pooling.py +110 -0
  98. mindspore/mint/optim/adamw.py +26 -13
  99. mindspore/mint/special/__init__.py +63 -0
  100. mindspore/multiprocessing/__init__.py +2 -1
  101. mindspore/nn/__init__.py +0 -1
  102. mindspore/nn/cell.py +276 -96
  103. mindspore/nn/layer/activation.py +211 -44
  104. mindspore/nn/layer/basic.py +137 -10
  105. mindspore/nn/layer/embedding.py +137 -2
  106. mindspore/nn/layer/normalization.py +101 -5
  107. mindspore/nn/layer/padding.py +34 -48
  108. mindspore/nn/layer/pooling.py +161 -7
  109. mindspore/nn/layer/transformer.py +3 -3
  110. mindspore/nn/loss/__init__.py +2 -2
  111. mindspore/nn/loss/loss.py +84 -6
  112. mindspore/nn/optim/__init__.py +2 -1
  113. mindspore/nn/optim/adadelta.py +1 -1
  114. mindspore/nn/optim/adam.py +1 -1
  115. mindspore/nn/optim/lamb.py +1 -1
  116. mindspore/nn/optim/tft_wrapper.py +124 -0
  117. mindspore/nn/wrap/cell_wrapper.py +12 -23
  118. mindspore/nn/wrap/grad_reducer.py +5 -5
  119. mindspore/nn/wrap/loss_scale.py +17 -3
  120. mindspore/numpy/__init__.py +1 -1
  121. mindspore/numpy/array_creations.py +65 -68
  122. mindspore/numpy/array_ops.py +64 -60
  123. mindspore/numpy/fft.py +610 -75
  124. mindspore/numpy/logic_ops.py +11 -10
  125. mindspore/numpy/math_ops.py +85 -84
  126. mindspore/numpy/utils_const.py +4 -4
  127. mindspore/opencv_core452.dll +0 -0
  128. mindspore/opencv_imgcodecs452.dll +0 -0
  129. mindspore/opencv_imgproc452.dll +0 -0
  130. mindspore/ops/__init__.py +6 -4
  131. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  132. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  133. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  134. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  135. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  136. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  137. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  138. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  139. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  140. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  141. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  142. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  143. mindspore/ops/composite/base.py +85 -48
  144. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  145. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  146. mindspore/ops/function/__init__.py +22 -0
  147. mindspore/ops/function/array_func.py +492 -153
  148. mindspore/ops/function/debug_func.py +113 -1
  149. mindspore/ops/function/fft_func.py +15 -2
  150. mindspore/ops/function/grad/grad_func.py +3 -2
  151. mindspore/ops/function/math_func.py +564 -207
  152. mindspore/ops/function/nn_func.py +817 -383
  153. mindspore/ops/function/other_func.py +3 -2
  154. mindspore/ops/function/random_func.py +402 -12
  155. mindspore/ops/function/reshard_func.py +13 -11
  156. mindspore/ops/function/sparse_unary_func.py +1 -1
  157. mindspore/ops/function/vmap_func.py +3 -2
  158. mindspore/ops/functional.py +24 -14
  159. mindspore/ops/op_info_register.py +3 -3
  160. mindspore/ops/operations/__init__.py +7 -2
  161. mindspore/ops/operations/_grad_ops.py +2 -76
  162. mindspore/ops/operations/_infer_ops.py +1 -1
  163. mindspore/ops/operations/_inner_ops.py +71 -94
  164. mindspore/ops/operations/array_ops.py +14 -146
  165. mindspore/ops/operations/comm_ops.py +63 -53
  166. mindspore/ops/operations/custom_ops.py +83 -19
  167. mindspore/ops/operations/debug_ops.py +42 -10
  168. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  169. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  170. mindspore/ops/operations/math_ops.py +12 -223
  171. mindspore/ops/operations/nn_ops.py +20 -114
  172. mindspore/ops/operations/other_ops.py +7 -4
  173. mindspore/ops/operations/random_ops.py +46 -1
  174. mindspore/ops/primitive.py +18 -6
  175. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  176. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  177. mindspore/ops_generate/gen_constants.py +36 -0
  178. mindspore/ops_generate/gen_ops.py +67 -52
  179. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  180. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  181. mindspore/ops_generate/op_proto.py +10 -3
  182. mindspore/ops_generate/pyboost_utils.py +14 -1
  183. mindspore/ops_generate/template.py +43 -21
  184. mindspore/parallel/__init__.py +3 -1
  185. mindspore/parallel/_auto_parallel_context.py +31 -9
  186. mindspore/parallel/_cell_wrapper.py +85 -0
  187. mindspore/parallel/_parallel_serialization.py +47 -19
  188. mindspore/parallel/_tensor.py +127 -13
  189. mindspore/parallel/_utils.py +53 -22
  190. mindspore/parallel/algo_parameter_config.py +5 -5
  191. mindspore/parallel/checkpoint_transform.py +46 -39
  192. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  193. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  194. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  195. mindspore/parallel/parameter_broadcast.py +3 -4
  196. mindspore/parallel/shard.py +162 -31
  197. mindspore/parallel/transform_safetensors.py +1146 -0
  198. mindspore/profiler/__init__.py +2 -1
  199. mindspore/profiler/common/constant.py +29 -0
  200. mindspore/profiler/common/registry.py +47 -0
  201. mindspore/profiler/common/util.py +28 -0
  202. mindspore/profiler/dynamic_profiler.py +694 -0
  203. mindspore/profiler/envprofiling.py +17 -19
  204. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  205. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  206. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  207. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  208. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  209. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  210. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  211. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  212. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  213. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  214. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  215. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  216. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  217. mindspore/profiler/parser/framework_parser.py +1 -391
  218. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  219. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  220. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  221. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  222. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  223. mindspore/profiler/parser/profiler_info.py +78 -6
  224. mindspore/profiler/profiler.py +153 -0
  225. mindspore/profiler/profiling.py +285 -413
  226. mindspore/rewrite/__init__.py +1 -2
  227. mindspore/rewrite/common/namespace.py +4 -4
  228. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  229. mindspore/run_check/_check_version.py +39 -104
  230. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  231. mindspore/swresample-4.dll +0 -0
  232. mindspore/swscale-6.dll +0 -0
  233. mindspore/tinyxml2.dll +0 -0
  234. mindspore/train/__init__.py +4 -3
  235. mindspore/train/_utils.py +105 -19
  236. mindspore/train/amp.py +171 -53
  237. mindspore/train/callback/__init__.py +2 -2
  238. mindspore/train/callback/_callback.py +4 -4
  239. mindspore/train/callback/_checkpoint.py +97 -31
  240. mindspore/train/callback/_cluster_monitor.py +1 -1
  241. mindspore/train/callback/_flops_collector.py +1 -0
  242. mindspore/train/callback/_loss_monitor.py +3 -3
  243. mindspore/train/callback/_on_request_exit.py +145 -31
  244. mindspore/train/callback/_summary_collector.py +5 -5
  245. mindspore/train/callback/_tft_register.py +375 -0
  246. mindspore/train/dataset_helper.py +15 -3
  247. mindspore/train/metrics/metric.py +3 -3
  248. mindspore/train/metrics/roc.py +4 -4
  249. mindspore/train/mind_ir_pb2.py +44 -39
  250. mindspore/train/model.py +154 -58
  251. mindspore/train/serialization.py +342 -128
  252. mindspore/turbojpeg.dll +0 -0
  253. mindspore/utils/__init__.py +21 -0
  254. mindspore/utils/utils.py +60 -0
  255. mindspore/version.py +1 -1
  256. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  257. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
  258. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
  259. mindspore/include/c_api/ms/abstract.h +0 -67
  260. mindspore/include/c_api/ms/attribute.h +0 -197
  261. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  262. mindspore/include/c_api/ms/base/macros.h +0 -32
  263. mindspore/include/c_api/ms/base/status.h +0 -33
  264. mindspore/include/c_api/ms/base/types.h +0 -283
  265. mindspore/include/c_api/ms/context.h +0 -102
  266. mindspore/include/c_api/ms/graph.h +0 -160
  267. mindspore/include/c_api/ms/node.h +0 -606
  268. mindspore/include/c_api/ms/tensor.h +0 -161
  269. mindspore/include/c_api/ms/value.h +0 -84
  270. mindspore/mindspore_shared_lib.dll +0 -0
  271. mindspore/nn/extend/basic.py +0 -140
  272. mindspore/nn/extend/embedding.py +0 -143
  273. mindspore/nn/extend/layer/normalization.py +0 -109
  274. mindspore/nn/extend/pooling.py +0 -117
  275. mindspore/nn/layer/embedding_service.py +0 -531
  276. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  277. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  278. mindspore/ops/extend/__init__.py +0 -53
  279. mindspore/ops/extend/array_func.py +0 -218
  280. mindspore/ops/extend/math_func.py +0 -76
  281. mindspore/ops/extend/nn_func.py +0 -308
  282. mindspore/ops/silent_check.py +0 -162
  283. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  284. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  285. mindspore/train/callback/_mindio_ttp.py +0 -443
  286. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  287. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
@@ -22,6 +22,7 @@ import os
22
22
  import sys
23
23
  import math
24
24
  import numbers
25
+ from contextlib import contextmanager
25
26
  import numpy as np
26
27
  from mindspore import log as logger
27
28
  from mindspore.log import _LogActionOnce
@@ -41,6 +42,8 @@ from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _i
41
42
  _is_ps_mode
42
43
  from mindspore.parallel._ps_context import _reinsert_hash_table_size, _insert_accumu_init_info, _cache_enable
43
44
  from mindspore.common._decorator import deprecated
45
+ from mindspore.communication._comm_helper import _is_initialized
46
+ from mindspore.communication import get_group_size
44
47
  import mindspore.common._monad as monad
45
48
 
46
49
  __all__ = ['Parameter', 'ParameterTuple']
@@ -52,11 +55,32 @@ PARAMETER_NAME_PREFIX_MAX_LEN = 1024
52
55
  _GLOBAL_PARAMETER_KEY = -1
53
56
 
54
57
 
55
- def _is_in_parallel_mode():
58
+ @contextmanager
59
+ def no_init_parameters():
60
+ init_class = globals()["Parameter"]
61
+ setattr(init_class, "init_param", False)
62
+ try:
63
+ yield
64
+ finally:
65
+ setattr(init_class, "init_param", True)
66
+
67
+
68
+ def _is_in_auto_parallel_mode():
56
69
  """Get parallel mode."""
57
70
  return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]
58
71
 
59
72
 
73
+ def _is_parallel_mode():
74
+ """ Whether is parallel mode """
75
+ if not _is_initialized() or context.get_context('mode') == context.PYNATIVE_MODE:
76
+ return False
77
+ if os.getenv("RUN_MODE") != "predict":
78
+ return False
79
+ if get_group_size() > 1 and _get_parallel_mode() == "stand_alone":
80
+ return True
81
+ return False
82
+
83
+
60
84
  def init_to_value(init):
61
85
  """
62
86
  Get value of initializer.
@@ -91,6 +115,15 @@ def _get_unique_parameter_key():
91
115
  return _GLOBAL_PARAMETER_KEY
92
116
 
93
117
 
118
+ def _gen_offload_file_path(offload_dir):
119
+ offload_dir = os.path.relpath(offload_dir)
120
+ if not os.path.exists(offload_dir):
121
+ os.makedirs(offload_dir, mode=0o700, exist_ok=True)
122
+ offload_file_path = offload_dir + "/" + str(_get_global_rank()) + "_" + str(
123
+ _get_unique_parameter_key()) + "_" + str(time.time()) + ".data"
124
+ return offload_file_path
125
+
126
+
94
127
  def _offload_if_config(data):
95
128
  """
96
129
  Offload parameter(data size > 512) to file when enable memory offload and offload parameter to disk.
@@ -111,11 +144,7 @@ def _offload_if_config(data):
111
144
  offload_file_path = data.offload_file_path()
112
145
  if offload_file_path is None or offload_file_path == "":
113
146
  offload_dir = offload_context.get("offload_path", "./offload")
114
- offload_dir = os.path.relpath(offload_dir)
115
- if not os.path.exists(offload_dir):
116
- os.makedirs(offload_dir)
117
- offload_file_path = offload_dir + "/" + str(_get_global_rank()) + "_" + str(
118
- _get_unique_parameter_key()) + "_" + str(time.time()) + ".data"
147
+ offload_file_path = _gen_offload_file_path(offload_dir)
119
148
  data.offload(offload_file_path)
120
149
 
121
150
 
@@ -191,6 +220,12 @@ class Parameter(Tensor_):
191
220
  storage_format (str): Only Ascend device target is supported. It is used to specify the format of the weight
192
221
  loaded to the device. By default, the format is not changed. The optional values are ``"FRACTAL_NZ"`` ,
193
222
  ``"NC1HWC0"`` , ``"FRACTAL_Z"`` , etc. Default: ``""`` .
223
+ device (str): Only Ascend device target is supported. It is used to specify the device on which the parameter is
224
+ stored. By default, the parameter will be stored on NPU while computing. When the device is specified as
225
+ ``"CPU"``, the parameter will be loaded into the device when it needs to be used, and unloaded to the CPU
226
+ after use. It takes effect only when `memory_offload` is ``"ON"``, `jit_level` is not ``"O2"`` and
227
+ `memory_optimize_level` is ``O0`` in `mindspore.set_context()`. Less device memory is needed when device is
228
+ specified as ``"CPU"``.
194
229
 
195
230
  Examples:
196
231
  >>> import numpy as np
@@ -219,7 +254,8 @@ class Parameter(Tensor_):
219
254
  def __new__(cls, default_input, *args, **kwargs):
220
255
  init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init)
221
256
  rc = sys.getrefcount(default_input)
222
- input_class, *class_init_args = Parameter._get_parameter_new_args(default_input, rc)
257
+ init_param = getattr(cls, "init_param", True)
258
+ input_class, *class_init_args = Parameter._get_parameter_new_args(default_input, rc, init_param)
223
259
  new_type = Parameter._get_base_class(input_class)
224
260
  obj = input_class.__new__(new_type)
225
261
  input_class.__init__(obj, *class_init_args)
@@ -244,7 +280,7 @@ class Parameter(Tensor_):
244
280
  Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
245
281
 
246
282
  def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True,
247
- storage_format=""):
283
+ storage_format="", device=None):
248
284
  self.param_info = ParamInfo()
249
285
  self.init_in_server = False
250
286
  self.name = name
@@ -263,7 +299,7 @@ class Parameter(Tensor_):
263
299
  self.requires_aggr = True
264
300
  self._cast_type = None
265
301
  self._unique = False
266
- self.is_in_parallel = _is_in_parallel_mode()
302
+ self.is_in_parallel = _is_in_auto_parallel_mode()
267
303
  self.is_in_shard = False
268
304
  self._pipeline_stage_list = []
269
305
  self.slice_num = 1
@@ -296,6 +332,10 @@ class Parameter(Tensor_):
296
332
  f" 'numpy.ndarray', 'list']. But got type {type(default_input)}.")
297
333
  self.param_info.parameter_shape = self.shape
298
334
  self.param_info.storage_format = storage_format
335
+ if device is not None:
336
+ if device != "CPU":
337
+ raise ValueError(f"Only 'CPU' is supported for device, but got ${device}.")
338
+ self._set_user_data("parameter_device", device)
299
339
 
300
340
  import mindspore.ops.operations.other_ops as other_ops
301
341
  self.load = other_ops.Load()
@@ -327,7 +367,7 @@ class Parameter(Tensor_):
327
367
  return new_type
328
368
 
329
369
  @staticmethod
330
- def _get_parameter_new_args(data, rc):
370
+ def _get_parameter_new_args(data, rc, init_param=True):
331
371
  """Set `set_data` of current `Parameter`."""
332
372
  if isinstance(data, bool):
333
373
  raise ValueError('Parameter data can not be `bool`')
@@ -342,7 +382,8 @@ class Parameter(Tensor_):
342
382
  return (Tensor, data.asnumpy(), mstype.qint4x2)
343
383
  return (Tensor, data.asnumpy())
344
384
 
345
- not_init_data = _is_role_sched() or (_is_role_pserver() and _cache_enable()) or _is_in_parallel_mode()
385
+ not_init_data = not init_param or _is_role_sched() or (_is_role_pserver() and _cache_enable()) \
386
+ or _is_in_auto_parallel_mode() or _is_parallel_mode()
346
387
  if not_init_data:
347
388
  # do not init data while in auto parallel.
348
389
  return (Tensor, None, data.dtype, get_slice_shape(data.dtype, data.shape), data.init)
@@ -368,7 +409,7 @@ class Parameter(Tensor_):
368
409
 
369
410
  Tutorial Examples:
370
411
  - `Parameter Server Mode
371
- <https://www.mindspore.cn/tutorials/experts/en/master/parallel/parameter_server_training.html>`_
412
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/parameter_server_training.html>`_
372
413
  """
373
414
  if not _is_ps_mode() or not (_is_role_worker() or _is_role_pserver() or _is_role_sched()):
374
415
  raise RuntimeError("Must complete following two steps before calling set_param_ps: \n"
@@ -616,6 +657,9 @@ class Parameter(Tensor_):
616
657
  shape = self.shape if self.slice_num == 1 else self.param_info.origin_shape
617
658
  dtype = self.dtype
618
659
  x.set_data(initializer(init, shape=shape, dtype=dtype))
660
+ device = self._get_user_data("parameter_device")
661
+ if device is not None:
662
+ x._set_user_data("parameter_device", device)
619
663
  return x
620
664
 
621
665
  @property
@@ -942,8 +986,10 @@ class Parameter(Tensor_):
942
986
  >>> x = Parameter(Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32)), name="param")
943
987
  >>> x.init_data()
944
988
  """
945
- if self.is_default_input_init and self.is_in_parallel != _is_in_parallel_mode():
989
+ if self.is_default_input_init and self.is_in_parallel != _is_in_auto_parallel_mode():
946
990
  raise RuntimeError("Must set or change parallel mode before any initializer Tensor created.")
991
+ if hasattr(self, "init_param") and self.init_param:
992
+ return self
947
993
  if self.init_mode is None:
948
994
  return self
949
995
  if self.inited_param is not None:
@@ -1026,8 +1072,9 @@ class ParameterTuple(tuple):
1026
1072
  Tuple, the new Parameter tuple.
1027
1073
 
1028
1074
  Tutorial Examples:
1029
- - `Cell and Parameter - Parameter Tuple
1030
- <https://mindspore.cn/tutorials/en/master/advanced/modules/layer.html#parameter-tuple>`_
1075
+ - `Tensor and Parameter - Parameter Tuple
1076
+ <https://mindspore.cn/docs/en/master/model_train/model_building/tensor_and_parameter.html
1077
+ #parameter-tuple>`_
1031
1078
  """
1032
1079
  Validator.check_str_by_regular(prefix)
1033
1080
  new = []
@@ -23,8 +23,10 @@ from mindspore.common.tensor import Tensor
23
23
  from mindspore import ops
24
24
  from mindspore.ops.composite import GradOperation
25
25
  from mindspore.common._register_for_recompute import recompute_registry
26
- from mindspore.common.api import _pynative_executor
26
+ from mindspore.common.api import _pynative_executor, _no_grad
27
27
  from mindspore.common.generator import get_rng_state, set_rng_state
28
+ from mindspore.train.amp import amp_decorator
29
+ from mindspore._c_expression.amp import get_curr_amp_strategy
28
30
 
29
31
 
30
32
  class _WrapCell(Cell):
@@ -34,7 +36,7 @@ class _WrapCell(Cell):
34
36
  """
35
37
 
36
38
  def __init__(self, function):
37
- super(_WrapCell, self).__init__()
39
+ super(_WrapCell, self).__init__(auto_prefix=False)
38
40
  self.function = function
39
41
 
40
42
  def construct(self, *args, **kwargs):
@@ -56,6 +58,7 @@ class _RecomputeCell(Cell):
56
58
  self.args = []
57
59
  self.kwargs = []
58
60
  self.wrap_cell = _WrapCell(block)
61
+ self.wrap_cell.set_inputs()
59
62
 
60
63
  self.net = block
61
64
  self.internal_params = []
@@ -64,15 +67,18 @@ class _RecomputeCell(Cell):
64
67
  self._add_attr("is_cell_recompute", "True")
65
68
  self.grad = GradOperation(get_all=True, get_by_list=True, sens_param=True)
66
69
  self.init_mixed_precision_type(block)
70
+ self.amp_strategy = None
67
71
 
68
72
  def construct(self, *args, **kwargs):
69
- _check_input_args_validate(self.net, args)
73
+ _check_input_args_validate(self.net, args, kwargs)
70
74
  self.args.append(args)
71
75
  self.kwargs.append(kwargs)
72
76
  self.save_rng_state = kwargs.pop("save_rng_state", True)
73
77
  if self.save_rng_state:
74
78
  self.cpu_rng_state = get_rng_state()
75
- return self.net(*args, **kwargs)
79
+ self.amp_strategy = get_curr_amp_strategy()
80
+ with _no_grad():
81
+ return self.net(*args, **kwargs)
76
82
 
77
83
  def bprop(self, *args):
78
84
  """
@@ -86,14 +92,23 @@ class _RecomputeCell(Cell):
86
92
  self.args.pop()
87
93
  self.kwargs.pop()
88
94
  if kwargs:
89
- input_args = list(input_args) + list(kwargs.values())
95
+ input_args_for_check = list(input_args) + list(kwargs.values())
96
+ else:
97
+ input_args_for_check = list(input_args)
90
98
  # To detach inputs to avoid erasing auto grad meta info of origin inputs.
91
99
  input_args = _detach_input(input_args)
100
+ kwargs = _detach_input(kwargs)
101
+ kwargs['sens'] = grad_input
92
102
  try:
93
103
  pre_rng_state = get_rng_state()
94
104
  set_rng_state(self.cpu_rng_state)
95
105
  _pynative_executor.set_is_run_recompute(True)
96
- grads = self.grad(self.net, self.internal_params)(*input_args, grad_input)
106
+ if self.amp_strategy:
107
+ with amp_decorator(self.amp_strategy.get_amp_level(), self.amp_strategy.get_amp_dtype(),
108
+ self.amp_strategy.get_white_list(), self.amp_strategy.get_black_list()):
109
+ grads = self.grad(self.net, self.internal_params)(*input_args, **kwargs)
110
+ else:
111
+ grads = self.grad(self.net, self.internal_params)(*input_args, **kwargs)
97
112
  _pynative_executor.set_is_run_recompute(False)
98
113
  set_rng_state(pre_rng_state)
99
114
  except Exception as err:
@@ -101,7 +116,7 @@ class _RecomputeCell(Cell):
101
116
  raise err
102
117
  weights = OrderedDict()
103
118
  input_grads = list(grads[0])
104
- _padding_input_grads(input_args, input_grads)
119
+ _padding_input_grads(input_args_for_check, input_grads)
105
120
  for i, param in enumerate(self.internal_params):
106
121
  weights[param] = grads[1][i]
107
122
  return tuple(input_grads), weights
@@ -121,6 +136,7 @@ class _RecomputeCell(Cell):
121
136
  # To avoid sub cell same name
122
137
  block.__self__.check_names_and_refresh_name()
123
138
  self.internal_params = block.__self__.trainable_params()
139
+ self.wrap_cell.mixed_precision_type = block.__self__.get_mixed_precision_type()
124
140
  self.wrap_cell.set_mixed_precision_type(block.__self__.get_mixed_precision_type())
125
141
  self.net = self.wrap_cell
126
142
  else:
@@ -128,13 +144,14 @@ class _RecomputeCell(Cell):
128
144
  "only support Cell object or MethodType function!")
129
145
 
130
146
 
131
- def _check_input_args_validate(block, args):
147
+ def _check_input_args_validate(block, args, kwargs):
132
148
  """
133
149
  Check recompute input args validate
134
150
  :param args:
135
151
  :return:
136
152
  """
137
- if not any([isinstance(arg, Tensor) for arg in args]):
153
+ if not (any([isinstance(arg, Tensor) for arg in args]) or \
154
+ any([isinstance(arg, Tensor) for arg in kwargs.values()])):
138
155
  logger.warning("None of the inputs of function are tensors, which not need use recompute!")
139
156
  for arg in args:
140
157
  if isinstance(arg, (tuple, list)):
@@ -168,6 +185,11 @@ def _padding_input_grads(args, input_grads):
168
185
 
169
186
 
170
187
  def _detach_input(input_arg):
188
+ """
189
+ Detach input
190
+ :param input_arg:
191
+ :return: detach output
192
+ """
171
193
  if isinstance(input_arg, Tensor):
172
194
  return ops.stop_gradient(input_arg)
173
195
  if isinstance(input_arg, (list, tuple)):
@@ -175,6 +197,14 @@ def _detach_input(input_arg):
175
197
  for arg in input_arg:
176
198
  detach_inputs.append(_detach_input(arg))
177
199
  return detach_inputs if isinstance(input_arg, list) else tuple(detach_inputs)
200
+ if isinstance(input_arg, dict):
201
+ detach_inputs = {}
202
+ for key, val in input_arg.items():
203
+ if isinstance(val, Tensor):
204
+ detach_inputs[key] = ops.stop_gradient(val)
205
+ else:
206
+ detach_inputs[key] = val
207
+ return detach_inputs
178
208
  return input_arg
179
209
 
180
210
 
@@ -97,7 +97,8 @@ class RowTensor(RowTensorInner):
97
97
  [0, 0]]
98
98
 
99
99
  .. warning::
100
- This is an experimental API that is subjected to change or deletion.
100
+ - This is an experimental API that is subjected to change or deletion.
101
+ - If using PyNative mode, set "export MS_PYNATIVE_CONFIG_STATIC_SHAPE=1".
101
102
 
102
103
  Args:
103
104
  indices (Tensor): A 1-D integer Tensor of shape :math:`(d_0)` . Default: ``None``.
@@ -226,10 +227,11 @@ class COOTensor(COOTensor_):
226
227
 
227
228
  Common arithmetic operations include: addition (+), subtraction (-), multiplication (*),
228
229
  and division (/). For details about operations supported by `COOTensor`, see
229
- `operators <https://www.mindspore.cn/docs/en/master/note/static_graph_syntax_support.html#operators>`_.
230
+ `operators <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#operators>`_.
230
231
 
231
232
  .. warning::
232
233
  - This is an experimental API that is subject to change or deletion.
234
+ - If using PyNative mode, set "export MS_PYNATIVE_CONFIG_STATIC_SHAPE=1".
233
235
  - Currently, duplicate coordinates in the indices will not be coalesced.
234
236
  If the indices contain out-of-bound values, the result will be undefined.
235
237
 
@@ -646,6 +648,7 @@ class CSRTensor(CSRTensor_):
646
648
  [1., 2., 3., 4., 5., 6.], shape is (3, 5), then the dense representation of the sparse tensor will be:
647
649
 
648
650
  .. code-block::
651
+
649
652
  [[1., 0., 0., 2., 0.],
650
653
  [0., 3., 4., 0., 5.],
651
654
  [0., 0., 6., 0., 0.]]
@@ -668,10 +671,11 @@ class CSRTensor(CSRTensor_):
668
671
 
669
672
  Common arithmetic operations include: addition (+), subtraction (-), multiplication (*),
670
673
  and division (/). For details about operations supported by `CSRTensor`, see
671
- `operators <https://www.mindspore.cn/docs/en/master/note/static_graph_syntax_support.html#operators>`_.
674
+ `operators <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#operators>`_.
672
675
 
673
676
  .. warning::
674
677
  - This is an experimental API that is subjected to change.
678
+ - If using PyNative mode, set "export MS_PYNATIVE_CONFIG_STATIC_SHAPE=1".
675
679
  - If the values given by `indptr` or `indices` are invalid, the results may be undefined. Invalid values include
676
680
  when the length of `values` or `indices` exceeds the range indicated by `indptr`, and when the columns
677
681
  indicated by `indices` are repeated on the same row.