mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (285)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/mindir_util.py +2 -2
  26. mindspore/common/parameter.py +46 -13
  27. mindspore/common/recompute.py +39 -9
  28. mindspore/common/sparse_tensor.py +7 -3
  29. mindspore/common/tensor.py +209 -29
  30. mindspore/communication/__init__.py +1 -1
  31. mindspore/communication/_comm_helper.py +38 -3
  32. mindspore/communication/comm_func.py +310 -55
  33. mindspore/communication/management.py +14 -14
  34. mindspore/context.py +123 -22
  35. mindspore/dataset/__init__.py +1 -1
  36. mindspore/dataset/audio/__init__.py +1 -1
  37. mindspore/dataset/core/config.py +7 -0
  38. mindspore/dataset/core/validator_helpers.py +7 -0
  39. mindspore/dataset/engine/cache_client.py +1 -1
  40. mindspore/dataset/engine/datasets.py +72 -44
  41. mindspore/dataset/engine/datasets_audio.py +7 -7
  42. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  43. mindspore/dataset/engine/datasets_text.py +20 -20
  44. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  45. mindspore/dataset/engine/datasets_vision.py +33 -33
  46. mindspore/dataset/engine/iterators.py +29 -0
  47. mindspore/dataset/engine/obs/util.py +7 -0
  48. mindspore/dataset/engine/queue.py +114 -60
  49. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  50. mindspore/dataset/engine/validators.py +34 -14
  51. mindspore/dataset/text/__init__.py +1 -4
  52. mindspore/dataset/transforms/__init__.py +0 -3
  53. mindspore/dataset/utils/line_reader.py +2 -0
  54. mindspore/dataset/vision/__init__.py +1 -4
  55. mindspore/dataset/vision/utils.py +1 -1
  56. mindspore/dataset/vision/validators.py +2 -1
  57. mindspore/dnnl.dll +0 -0
  58. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  59. mindspore/experimental/es/embedding_service.py +883 -0
  60. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  61. mindspore/experimental/llm_boost/__init__.py +21 -0
  62. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  63. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  64. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  65. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  66. mindspore/experimental/llm_boost/register.py +129 -0
  67. mindspore/experimental/llm_boost/utils.py +31 -0
  68. mindspore/experimental/optim/adamw.py +85 -0
  69. mindspore/experimental/optim/optimizer.py +3 -0
  70. mindspore/hal/__init__.py +3 -3
  71. mindspore/hal/contiguous_tensors_handle.py +175 -0
  72. mindspore/hal/stream.py +18 -0
  73. mindspore/include/api/model_group.h +13 -1
  74. mindspore/include/api/types.h +10 -10
  75. mindspore/include/dataset/config.h +2 -2
  76. mindspore/include/dataset/constants.h +2 -2
  77. mindspore/include/dataset/execute.h +2 -2
  78. mindspore/include/dataset/vision.h +4 -0
  79. mindspore/jpeg62.dll +0 -0
  80. mindspore/log.py +1 -1
  81. mindspore/mindrecord/filewriter.py +68 -51
  82. mindspore/mindspore_backend.dll +0 -0
  83. mindspore/mindspore_common.dll +0 -0
  84. mindspore/mindspore_core.dll +0 -0
  85. mindspore/mindspore_glog.dll +0 -0
  86. mindspore/mindspore_np_dtype.dll +0 -0
  87. mindspore/mindspore_ops.dll +0 -0
  88. mindspore/mint/__init__.py +495 -46
  89. mindspore/mint/distributed/__init__.py +31 -0
  90. mindspore/mint/distributed/distributed.py +254 -0
  91. mindspore/mint/nn/__init__.py +266 -21
  92. mindspore/mint/nn/functional.py +125 -19
  93. mindspore/mint/nn/layer/__init__.py +39 -0
  94. mindspore/mint/nn/layer/activation.py +133 -0
  95. mindspore/mint/nn/layer/normalization.py +477 -0
  96. mindspore/mint/nn/layer/pooling.py +110 -0
  97. mindspore/mint/optim/adamw.py +28 -7
  98. mindspore/mint/special/__init__.py +63 -0
  99. mindspore/multiprocessing/__init__.py +2 -1
  100. mindspore/nn/__init__.py +0 -1
  101. mindspore/nn/cell.py +275 -93
  102. mindspore/nn/layer/activation.py +211 -44
  103. mindspore/nn/layer/basic.py +113 -3
  104. mindspore/nn/layer/embedding.py +120 -2
  105. mindspore/nn/layer/normalization.py +101 -5
  106. mindspore/nn/layer/padding.py +34 -48
  107. mindspore/nn/layer/pooling.py +161 -7
  108. mindspore/nn/layer/transformer.py +3 -3
  109. mindspore/nn/loss/__init__.py +2 -2
  110. mindspore/nn/loss/loss.py +84 -6
  111. mindspore/nn/optim/__init__.py +2 -1
  112. mindspore/nn/optim/adadelta.py +1 -1
  113. mindspore/nn/optim/adam.py +1 -1
  114. mindspore/nn/optim/lamb.py +1 -1
  115. mindspore/nn/optim/tft_wrapper.py +127 -0
  116. mindspore/nn/wrap/cell_wrapper.py +12 -23
  117. mindspore/nn/wrap/grad_reducer.py +5 -5
  118. mindspore/nn/wrap/loss_scale.py +17 -3
  119. mindspore/numpy/__init__.py +1 -1
  120. mindspore/numpy/array_creations.py +65 -68
  121. mindspore/numpy/array_ops.py +64 -60
  122. mindspore/numpy/fft.py +610 -75
  123. mindspore/numpy/logic_ops.py +11 -10
  124. mindspore/numpy/math_ops.py +85 -84
  125. mindspore/numpy/utils_const.py +4 -4
  126. mindspore/opencv_core452.dll +0 -0
  127. mindspore/opencv_imgcodecs452.dll +0 -0
  128. mindspore/opencv_imgproc452.dll +0 -0
  129. mindspore/ops/__init__.py +6 -4
  130. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  131. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  132. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  133. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  134. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  135. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  136. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  137. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  138. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  139. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  140. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  141. mindspore/ops/composite/base.py +85 -48
  142. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  143. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  144. mindspore/ops/function/__init__.py +22 -0
  145. mindspore/ops/function/array_func.py +490 -153
  146. mindspore/ops/function/debug_func.py +113 -1
  147. mindspore/ops/function/fft_func.py +15 -2
  148. mindspore/ops/function/grad/grad_func.py +3 -2
  149. mindspore/ops/function/math_func.py +558 -207
  150. mindspore/ops/function/nn_func.py +817 -383
  151. mindspore/ops/function/other_func.py +3 -2
  152. mindspore/ops/function/random_func.py +184 -8
  153. mindspore/ops/function/reshard_func.py +13 -11
  154. mindspore/ops/function/sparse_unary_func.py +1 -1
  155. mindspore/ops/function/vmap_func.py +3 -2
  156. mindspore/ops/functional.py +24 -14
  157. mindspore/ops/op_info_register.py +3 -3
  158. mindspore/ops/operations/__init__.py +6 -1
  159. mindspore/ops/operations/_grad_ops.py +2 -76
  160. mindspore/ops/operations/_infer_ops.py +1 -1
  161. mindspore/ops/operations/_inner_ops.py +71 -94
  162. mindspore/ops/operations/array_ops.py +12 -146
  163. mindspore/ops/operations/comm_ops.py +42 -53
  164. mindspore/ops/operations/custom_ops.py +83 -19
  165. mindspore/ops/operations/debug_ops.py +42 -10
  166. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  167. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  168. mindspore/ops/operations/math_ops.py +12 -223
  169. mindspore/ops/operations/nn_ops.py +20 -114
  170. mindspore/ops/operations/other_ops.py +7 -4
  171. mindspore/ops/operations/random_ops.py +46 -1
  172. mindspore/ops/primitive.py +18 -6
  173. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  174. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  175. mindspore/ops_generate/gen_constants.py +36 -0
  176. mindspore/ops_generate/gen_ops.py +67 -52
  177. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  178. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  179. mindspore/ops_generate/op_proto.py +10 -3
  180. mindspore/ops_generate/pyboost_utils.py +14 -1
  181. mindspore/ops_generate/template.py +43 -21
  182. mindspore/parallel/__init__.py +3 -1
  183. mindspore/parallel/_auto_parallel_context.py +28 -8
  184. mindspore/parallel/_cell_wrapper.py +83 -0
  185. mindspore/parallel/_parallel_serialization.py +47 -19
  186. mindspore/parallel/_tensor.py +81 -11
  187. mindspore/parallel/_utils.py +13 -1
  188. mindspore/parallel/algo_parameter_config.py +5 -5
  189. mindspore/parallel/checkpoint_transform.py +46 -39
  190. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  191. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  192. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  193. mindspore/parallel/parameter_broadcast.py +3 -4
  194. mindspore/parallel/shard.py +162 -31
  195. mindspore/parallel/transform_safetensors.py +993 -0
  196. mindspore/profiler/__init__.py +2 -1
  197. mindspore/profiler/common/constant.py +29 -0
  198. mindspore/profiler/common/registry.py +47 -0
  199. mindspore/profiler/common/util.py +28 -0
  200. mindspore/profiler/dynamic_profiler.py +694 -0
  201. mindspore/profiler/envprofiling.py +17 -19
  202. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  203. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  204. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  205. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  206. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  207. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  208. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  209. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  210. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  211. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  212. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  213. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  214. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  215. mindspore/profiler/parser/framework_parser.py +1 -391
  216. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  217. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  218. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  219. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  220. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  221. mindspore/profiler/parser/profiler_info.py +78 -6
  222. mindspore/profiler/profiler.py +153 -0
  223. mindspore/profiler/profiling.py +280 -412
  224. mindspore/rewrite/__init__.py +1 -2
  225. mindspore/rewrite/common/namespace.py +4 -4
  226. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  227. mindspore/run_check/_check_version.py +36 -103
  228. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  229. mindspore/swresample-4.dll +0 -0
  230. mindspore/swscale-6.dll +0 -0
  231. mindspore/tinyxml2.dll +0 -0
  232. mindspore/train/__init__.py +4 -3
  233. mindspore/train/_utils.py +28 -2
  234. mindspore/train/amp.py +171 -53
  235. mindspore/train/callback/__init__.py +2 -2
  236. mindspore/train/callback/_callback.py +4 -4
  237. mindspore/train/callback/_checkpoint.py +85 -22
  238. mindspore/train/callback/_cluster_monitor.py +1 -1
  239. mindspore/train/callback/_flops_collector.py +1 -0
  240. mindspore/train/callback/_loss_monitor.py +3 -3
  241. mindspore/train/callback/_on_request_exit.py +134 -31
  242. mindspore/train/callback/_summary_collector.py +5 -5
  243. mindspore/train/callback/_tft_register.py +352 -0
  244. mindspore/train/dataset_helper.py +7 -3
  245. mindspore/train/metrics/metric.py +3 -3
  246. mindspore/train/metrics/roc.py +4 -4
  247. mindspore/train/mind_ir_pb2.py +44 -39
  248. mindspore/train/model.py +134 -58
  249. mindspore/train/serialization.py +336 -112
  250. mindspore/turbojpeg.dll +0 -0
  251. mindspore/utils/__init__.py +21 -0
  252. mindspore/utils/utils.py +60 -0
  253. mindspore/version.py +1 -1
  254. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  255. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
  256. mindspore/include/c_api/ms/abstract.h +0 -67
  257. mindspore/include/c_api/ms/attribute.h +0 -197
  258. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  259. mindspore/include/c_api/ms/base/macros.h +0 -32
  260. mindspore/include/c_api/ms/base/status.h +0 -33
  261. mindspore/include/c_api/ms/base/types.h +0 -283
  262. mindspore/include/c_api/ms/context.h +0 -102
  263. mindspore/include/c_api/ms/graph.h +0 -160
  264. mindspore/include/c_api/ms/node.h +0 -606
  265. mindspore/include/c_api/ms/tensor.h +0 -161
  266. mindspore/include/c_api/ms/value.h +0 -84
  267. mindspore/mindspore_shared_lib.dll +0 -0
  268. mindspore/nn/extend/basic.py +0 -140
  269. mindspore/nn/extend/embedding.py +0 -143
  270. mindspore/nn/extend/layer/normalization.py +0 -109
  271. mindspore/nn/extend/pooling.py +0 -117
  272. mindspore/nn/layer/embedding_service.py +0 -531
  273. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  274. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  275. mindspore/ops/extend/__init__.py +0 -53
  276. mindspore/ops/extend/array_func.py +0 -218
  277. mindspore/ops/extend/math_func.py +0 -76
  278. mindspore/ops/extend/nn_func.py +0 -308
  279. mindspore/ops/silent_check.py +0 -162
  280. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  281. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  282. mindspore/train/callback/_mindio_ttp.py +0 -443
  283. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  284. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  285. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -38,12 +38,13 @@ from mindspore.common.tensor import Tensor as PythonTensor
  from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
  from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
  from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
+ from mindspore._c_expression.amp import get_curr_amp_strategy
  from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
  PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
- _ms_memory_recycle, _bind_device_ctx, jit_mode_pi_enable, jit_mode_pi_compile
+ _ms_memory_recycle, _bind_device_ctx
  from mindspore.parallel._ps_context import _is_role_sched
  from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
- _is_in_auto_parallel_mode
+ _is_in_auto_parallel_mode, _is_parallel_mode
  from mindspore import _checkparam as Validator
  from mindspore._checkparam import is_stub_tensor
  from mindspore.common._utils import is_shape_unknown
@@ -51,6 +52,8 @@ from mindspore.common.mutable import mutable
  from mindspore.common._register_for_adapter import ms_adapter_registry
  from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
  get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
+ from mindspore.common._pijit_context import PIJitCaptureContext
+ from mindspore.common.parameter import Parameter

  # Store ms_function class compiled pipeline cache.
  ms_compile_cache = set()
@@ -513,6 +516,19 @@ def _generate_dyn_compile_args(compile_args, dyn_args):
  return tuple(new_compile_args)


+ def _get_parameter_ids(args, kwargs):
+ """Get the ids of parameters."""
+ parameter_ids = ""
+ for arg in args:
+ if isinstance(arg, Parameter):
+ parameter_ids += str(id(arg))
+ for _, value in kwargs.items():
+ # The type of key is usually String type.
+ if isinstance(value, Parameter):
+ parameter_ids += str(id(value))
+ return parameter_ids
+
+
  class _MindsporeFunctionExecutor:
  """
  Represents a function compiled by graph compiler.
@@ -625,6 +641,10 @@ class _MindsporeFunctionExecutor:

  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
  key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
+
+ parameter_ids = _get_parameter_ids(args, kwargs)
+ if parameter_ids != "":
+ key = str(key) + '.' + parameter_ids
  phase = generate_name + '.' + str(key)

  update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
@@ -783,31 +803,28 @@ def _get_jit_hash(hash_input):
  return _get_obj_id(hash_input)


- def _update_graph_executor_config(jit_config):
- """Update GraphExecutor jit_config"""
- if isinstance(jit_config, JitConfig):
- jit_config = jit_config.jit_config_dict
- if not isinstance(jit_config, dict):
- return
- valid_config = dict()
- for k, v in jit_config.items():
- valid_config[str(k)] = str(v)
- GraphExecutor_.get_instance().set_jit_config(JitConfig(**valid_config).jit_config_dict)
-
-
  def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=None, compile_once=False):
  """
  Create a callable MindSpore graph from a Python function.

  This allows the MindSpore runtime to apply optimizations based on graph.

+ Note:
+ - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
+ will not accept `**kwargs`.
+ - It is not supported to run a function with decoration @jit(mode="PIJit")
+ in static graph mode, in which case the decoration @jit(mode="PIJit") is considered invalid.
+ - Calls to functions with decorated @jit(mode="PIJit") inside functions
+ decorated with @jit(mode="PIJit") are not supported,
+ and the decoration @jit(mode="PIJit") is considered invalid.
+
  Args:
  fn (Function): The Python function that will be run as a graph. Default: ``None`` .
  mode (str): The type of jit used, the value of mode should be ``PIJit`` or ``PSJit``. Default: ``PSJit`` .

- - `PSJit <https://www.mindspore.cn/docs/en/master/note/static_graph_syntax_support.html>`_ :
+ - `PSJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html>`_ :
  Parse python ast to build graph.
- - `PIJit <https://www.mindspore.cn/docs/en/master/design/dynamic_graph_and_static_graph.html>`_ :
+ - `PIJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/pynative.html#pijit>`_ :
  Parse python bytecode to build graph at runtime.

  input_signature (Union[Tuple, List, Dict, Tensor]): The Tensor which describes the input arguments. The
@@ -831,10 +848,6 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
  it was created again.
  Default: ``False`` .

- Note:
- If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
- will not accept `**kwargs`.
-
  Returns:
  Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
  None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
@@ -938,45 +951,20 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
  # only the function or cell instance wrapped by shard will fall into this branch
  if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
  process_obj = hash_args
+ # Handle auto mixed precision strategy.
+ if not hasattr(func, "amp_strategy"):
+ if isinstance(func, types.MethodType):
+ setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
+ else:
+ setattr(func, "amp_strategy", get_curr_amp_strategy())
  out = _MindsporeFunctionExecutor(func, hash_obj, dyn_args, process_obj, jit_config)(*args, **kwargs)
  return out

  return staging_specialize

- def pi_wrap_mindspore(decorated):
- func = decorated
- if isinstance(func, ms.nn.Cell):
- func = func.construct
- if isinstance(func, type) and issubclass(func, ms.nn.Cell):
- func = func.construct
- if isinstance(func, types.MethodType):
- func = func.__func__
- if not isinstance(func, types.FunctionType):
- logger.warning("only support function and mindspore.nn.Cell instance")
- return decorated
-
- # generator, coroutine, awaitable and a function that return them is unsupported
- UNSUPPORTED_CODE_TYPE = (inspect.CO_GENERATOR | inspect.CO_COROUTINE |
- inspect.CO_ASYNC_GENERATOR | inspect.CO_ITERABLE_COROUTINE)
- if func.__code__.co_flags & UNSUPPORTED_CODE_TYPE:
- return decorated
-
- _update_graph_executor_config(jit_config)
- config = dict()
- if isinstance(jit_config, JitConfig):
- config.update(jit_config.jit_config_dict)
- elif jit_config is not None:
- config.update(jit_config)
- jit_mode_pi_enable()
-
- if jit_mode_pi_compile(func, config, input_signature) is False:
- logger.warning('add fn {} to compile failed '.format(func))
-
- return decorated
-
  wrap_func = wrap_mindspore
  if mode == "PIJit":
- wrap_func = pi_wrap_mindspore
+ wrap_func = PIJitCaptureContext(jit_config, input_signature)

  if fn is not None:
  return wrap_func(fn)
@@ -1272,7 +1260,7 @@ def jit_class(cls):
  if not inspect.isclass(cls):
  raise TypeError(f'Decorator jit_class can only be used for class type, but got {cls}.')
  # Check if cls is nn.Cell.
- if issubclass(cls, nn.Cell):
+ if issubclass(cls, nn.cell.Cell):
  raise TypeError(f"Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
  setattr(cls, '__ms_class__', True)
  return cls
@@ -1463,23 +1451,22 @@ class _PyNativeExecutor:
  """
  self._executor.end_graph(obj, output, *args, *(kwargs.values()))

- def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs):
+ def check_run(self, grad, obj, weights, grad_hash_id, *args):
  """
  Whether the forward graph need to construct.

  Args:
  grad (GradOperation): The gradoperation object.
  obj (Function/Cell): The function or cell instance.
- grad_hash_id (tuple): The id of objects which contribute to cache of compiled graph in pynative mode.
+ grad_hash_id (tuple): The id of objects, which contributes to cache of compiled graph in pynative mode.
  args (tuple): Function or cell input arguments.
- kwargs (dict): keyword arguments.

  Return:
- bool, specifies whether the forward graph need to construct.
+ bool, specifies whether the forward graph needs to construct.
  """
- return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, *(kwargs.values()))
+ return self._executor.check_run(grad, obj, weights, grad_hash_id, *args)

- def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
+ def grad(self, obj, grad, weights, grad_position, *args):
  """
  Get grad graph.

@@ -1490,12 +1477,11 @@ class _PyNativeExecutor:
  grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
  If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
  args (tuple): Function or cell input arguments.
- kwargs (dict): keyword arguments.

  Return:
  None.
  """
- return self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
+ return self._executor.grad(grad, obj, weights, grad_position, *args)

  def clear_res(self):
  """
@@ -1528,9 +1514,23 @@ class _PyNativeExecutor:
  """
  return self._executor.grad_jit(output, *args)

+ def call_custom_bprop(self, obj, output, *args, **kwargs):
+ """
+ Call custom bprop to build variable for cell bprop.
+ Args:
+ obj (Cell): The function or cell instance.
+ output (Tensor/tuple/list): Function or cell output object.
+ args (tuple): Function or cell input arguments.
+ kwargs (dict): keyword arguments.
+
+ Return:
+ None.
+ """
+ return self._executor.call_custom_bprop(obj, output, *args, *(kwargs.values()))
+
  def grad_flag(self):
  """
- The flag of building grad graph.
+ The flag of whether the net building grad graph.

  Return:
  bool, whether building grad graph.
@@ -1563,7 +1563,7 @@ class _PyNativeExecutor:

  def enable_grad(self):
  """
- The global flag whether needing to calculate gradient.
+ The global flag that whether need to calculate gradient use in no_grad.

  Return:
  bool, whether needing to calculate gradient.
@@ -1582,6 +1582,18 @@ class _PyNativeExecutor:
  """
  self._executor.set_enable_grad(flag)

+ def requires_grad(self):
+ """
+ When both enable_grad is true and grad_flag is true, that the flag requires_grad will be true.
+
+ Args:
+ flag (bool): Specifying whether calculating gradient.
+
+ Return:
+ None.
+ """
+ return self._executor.requires_grad()
+
  def set_jit_compile_status(self, status, phase):
  """
  Set jit is compiling
@@ -1605,6 +1617,18 @@ class _PyNativeExecutor:
  """
  self._executor.set_is_run_recompute(status)

+ def set_cell_use_dynamic_shape_process(self, flag):
+ """
+ Set the dynamic shape flag of eval process.
+
+ Args:
+ flag (bool): Specifying whether using a dynamic process.
+
+ Return:
+ None.
+ """
+ self._executor.set_cell_use_dynamic_shape_process(flag)
+
  def set_dynamic_input(self, obj, *args):
  """
  Set dynamic shape tensor of input arguments.
@@ -1630,27 +1654,19 @@ class _PyNativeExecutor:
  """
  return self._executor.get_dynamic_input(*actual_args)

- def is_first_cell(self):
- """
- The flag of first cell instance.
-
- Return:
- bool, specifies whether is the first cell.
+ def set_mixed_precision_type(self, mixed_precision_type, is_push=True):
  """
-
- return self._executor.is_first_cell()
-
- def set_hook_changed(self, cell):
- """
- The flag of registering or removing a hook function on Cell instance.
+ The value of mixed precision type.

  Args:
- cell (Cell): The cell instance.
+ type(MixedPrecisionType): Mix precision type.
+ is_push(bool): If called by __enter__, is push will be True

  Return:
  None.
  """
- self._executor.set_hook_changed(cell)
+
+ return self._executor.set_mixed_precision_type(mixed_precision_type, is_push)

  def constant_folding(self, *args):
  """
@@ -1687,6 +1703,7 @@ class _CellGraphExecutor:
  self._graph_executor = GraphExecutor_.get_instance()
  self._graph_executor.set_py_exe_path(sys.executable)
  self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
+ self._pid = os.getpid()

  def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
  input_indexs, phase='dataset', need_run=True):
@@ -1789,6 +1806,10 @@ class _CellGraphExecutor:
  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
  key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
  obj.arguments_key = str(key)
+ # When exist parameter in the top graph inputs, need check if the parameter object has changed.
+ parameter_ids = _get_parameter_ids(args, kwargs)
+ if parameter_ids != "":
+ obj.arguments_key = obj.arguments_key + '.' + parameter_ids
  raw_phase = phase
  phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
  obj.phase_cache[raw_phase] = phase
@@ -1825,7 +1846,7 @@ class _CellGraphExecutor:
  if graph is None:
  raise RuntimeError("Compile graph failed for phase {}.".format(phase))

- auto_parallel_mode = _is_in_auto_parallel_mode()
+ auto_parallel_mode = _is_in_auto_parallel_mode() or _is_parallel_mode()
  if not auto_parallel_mode:
  replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
  self._update_param_node_default_input(phase, replace)
@@ -1913,15 +1934,9 @@ class _CellGraphExecutor:

  def del_net_res(self, obj, net_id):
  """Clear the memory resource of a network."""
- self._graph_executor.del_net_res(obj, net_id)
-
- def inc_graph_cell_count(self):
- """Increase the count of GraphCell instance."""
- self._graph_executor.inc_graph_cell_count()
-
- def dec_graph_cell_count(self):
- """Decrease the count of GraphCell instance."""
- self._graph_executor.dec_graph_cell_count()
+ # no need to del net res by gc in independent dataset process which is a subprocess forked by main process
+ if self._pid == os.getpid():
+ self._graph_executor.del_net_res(obj, net_id)

  def _get_branch_control_input(self):
  if ('obf_ratio' not in self.obfuscate_config.keys()) or (
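The compile-cache change above can be illustrated with a small standalone sketch. The body of _get_parameter_ids below is copied from the hunk above; the two Parameter objects and the final print are illustrative only and are not part of the package.

import numpy as np
from mindspore import Parameter, Tensor

def _get_parameter_ids(args, kwargs):
    """Mirror of the 2.4.0 helper: concatenate the ids of Parameter arguments."""
    parameter_ids = ""
    for arg in args:
        if isinstance(arg, Parameter):
            parameter_ids += str(id(arg))
    for _, value in kwargs.items():
        if isinstance(value, Parameter):
            parameter_ids += str(id(value))
    return parameter_ids

# Two parameters with identical values are still distinct objects, so the
# compile key gains a different suffix and a separate graph phase is compiled.
p1 = Parameter(Tensor(np.ones((2, 2), np.float32)), name="p1")
p2 = Parameter(Tensor(np.ones((2, 2), np.float32)), name="p2")
print(_get_parameter_ids((p1,), {}) != _get_parameter_ids((p2,), {}))  # True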
mindspore/common/dump.py CHANGED
@@ -27,18 +27,17 @@ def set_dump(target, enabled=True):
  `target` should be an instance of :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` .
  Please note that this API takes effect only when Synchronous Dump is enabled and the `dump_mode`
  field in dump config file is ``"2"`` . See the `dump document
- <https://www.mindspore.cn/tutorials/experts/en/master/debug/dump.html>`_ for details.
+ <https://www.mindspore.cn/docs/en/master/model_train/debug/dump.html>`_ for details.
  The default enabled status for
  a :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` is False.

  Note:
- 1. This API is only effective for GRAPH_MODE whose graph compilation level is O0/O1 with Ascend backend.
- 2. This API only supports being called before training starts.
+ 1. This API only supports being called before training starts.
  If you call this API during training, it may not be effective.
- 3. After using `set_dump(Cell, True)` , operators in forward and backward
+ 2. After using `set_dump(Cell, True)` , operators in forward and backward
  computation (computation generated by the grad operations) of the
  cell will be dumped.
- 4. For :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer, the forward
+ 3. For :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer, the forward
  computation and backward computation use the same set of
  operators. So you can only see dump data from backward computation.
  Please note that :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer will also use
@@ -58,7 +57,7 @@ def set_dump(target, enabled=True):
  .. note::
  Please set environment variable `MINDSPORE_DUMP_CONFIG` to the dump config file and set `dump_mode` field
  in dump config file to 2 before running this example.
- See `dump document <https://www.mindspore.cn/tutorials/experts/en/master/debug/dump.html>`_ for details.
+ See `dump document <https://www.mindspore.cn/docs/en/master/model_train/debug/dump.html>`_ for details.

  >>> import numpy as np
  >>> import mindspore as ms
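For context, here is a hedged usage sketch of mindspore.set_dump as documented above. The dump config path is a placeholder; the call only takes effect on an Ascend backend with Synchronous Dump enabled and dump_mode set to "2" in that config file, and it should be issued before training starts.

import os
os.environ["MINDSPORE_DUMP_CONFIG"] = "/path/to/dump_config.json"  # placeholder path

import numpy as np
import mindspore as ms
from mindspore import nn, set_dump, Tensor

ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(4, 2)

    def construct(self, x):
        return self.dense(x)

net = Net()
set_dump(net.dense)                 # enabled=True by default
out = net(Tensor(np.ones((1, 4), np.float32)))
set_dump(net.dense, enabled=False)  # dumping can be switched off again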
mindspore/common/generator.py CHANGED
@@ -56,12 +56,6 @@ class Generator:
  A generator that manages the state of random numbers and provides seed and offset for random functions.
  When the seed and offset are fixed, the random function generates the same random sequence.

- Inputs:
- - **step** (int) - Set the step size for offset update.
-
- Outputs:
- Tuple consisting of the seed and offset of generator.
-
  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

@@ -199,7 +193,7 @@ def manual_seed(seed): # pylint: disable=redefined-outer-name
  >>> print(initial_seed())
  13
  """
- default_generator.manual_seed(seed)
+ return default_generator.manual_seed(seed)


  def initial_seed():
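A short sketch of the behaviour documented above, assuming manual_seed and initial_seed are importable from the top-level mindspore package as in the official examples. Per the hunk above, 2.4.0 forwards the result of default_generator.manual_seed(seed) instead of discarding it; whether that value is the generator object itself depends on Generator.manual_seed, which this diff does not show.

from mindspore import manual_seed, initial_seed

ret = manual_seed(13)   # seeds the default generator; 2.4.0 also returns the underlying result
print(initial_seed())   # 13, as in the docstring example above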
mindspore/common/hook_handle.py CHANGED
@@ -77,27 +77,19 @@ class HookHandle:
  It is only supported in pynative mode and works when registering or removing hook function for Cell object.

  Args:
- hook_cell (Cell): The Cell object with hook function registered on. Default value: None.
- hook_key (int): The key of cell hook function in dict. It is generated during cell hook function registration.
- Default value: -1.
- hook_type (str): The type of cell hook function: '_forward_pre_hook', '_forward_hook' or '_cell_backward_hook'.
- Default value: "".
+ hook_dict (Dict): The hook object with hook function registered on. Default value: None.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
  """
- def __init__(self, hook_cell=None, hook_key=-1, hook_type=""):
- if hook_cell is not None:
- self._hook_cell = weakref.ref(hook_cell)
- else:
- self._hook_cell = hook_cell
- self._hook_key = hook_key
- self._hook_type = hook_type
-
- def __del__(self):
- self._hook_cell = None
- self._hook_key = None
- self._hook_type = None
+ unique_id = 0
+
+ def __init__(self, hook_dict=None):
+ self.hook_dict_ref = None
+ if hook_dict is not None:
+ self.hook_dict_ref = weakref.ref(hook_dict)
+ self.handle_id = HookHandle.unique_id
+ HookHandle.unique_id += 1

  def remove(self):
  """
@@ -121,7 +113,7 @@ class HookHandle:
  >>> from mindspore import Tensor
  >>> from mindspore.ops import GradOperation
  >>> ms.set_context(mode=ms.PYNATIVE_MODE)
- >>> def forward_pre_hook_fn(cell_id, inputs):
+ >>> def forward_pre_hook_fn(cell, inputs):
  ... print("forward inputs: ", inputs)
  ...
  >>> class Net(nn.Cell):
@@ -145,11 +137,7 @@ class HookHandle:
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
  value= [ 2.00000000e+00]))
  """
- if self._hook_cell is not None:
- hook_cell = self._hook_cell()
- if self._hook_type == "_forward_pre_hook" and self._hook_key in hook_cell._forward_pre_hook:
- del hook_cell._forward_pre_hook[self._hook_key]
- elif self._hook_type == "_forward_hook" and self._hook_key in hook_cell._forward_hook:
- del hook_cell._forward_hook[self._hook_key]
- elif self._hook_type == "_cell_backward_hook":
- hook_cell._cell_backward_hook.remove_backward_hook(self._hook_key)
+ if self.hook_dict_ref is not None:
+ hook_dict = self.hook_dict_ref()
+ if hook_dict is not None and self.handle_id in hook_dict:
+ del hook_dict[self.handle_id]
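A hedged sketch of the hook flow the docstring above refers to: in 2.4.0 the forward pre-hook receives the Cell instance rather than a cell id string, and the returned handle removes its own entry from the owning hook dict. Only public Cell APIs are used here; the network itself is illustrative.

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)

def forward_pre_hook_fn(cell, inputs):
    # 2.4.0 signature: the cell instance is passed in, not a cell id string.
    print("forward inputs:", inputs)

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(x)

net = Net()
handle = net.relu.register_forward_pre_hook(forward_pre_hook_fn)
net(Tensor(np.ones([1]).astype(np.float32)))   # hook fires
handle.remove()                                # drops the handle's entry from the hook dict
net(Tensor(np.ones([1]).astype(np.float32)))   # hook no longer fires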
mindspore/common/mindir_util.py CHANGED
@@ -90,9 +90,9 @@ def save_mindir(model, file_name):
  if not file_name.endswith('.mindir'):
  file_name += ".mindir"

- current_path = os.path.abspath(file_name)
+ current_path = os.path.realpath(file_name)
  dirname = os.path.dirname(current_path)
- os.makedirs(dirname, exist_ok=True)
+ os.makedirs(dirname, mode=0o700, exist_ok=True)
  if os.path.exists(file_name):
  os.chmod(file_name, stat.S_IWUSR)

mindspore/common/parameter.py CHANGED
@@ -41,6 +41,8 @@ from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _i
  _is_ps_mode
  from mindspore.parallel._ps_context import _reinsert_hash_table_size, _insert_accumu_init_info, _cache_enable
  from mindspore.common._decorator import deprecated
+ from mindspore.communication._comm_helper import _is_initialized
+ from mindspore.communication import get_group_size
  import mindspore.common._monad as monad

  __all__ = ['Parameter', 'ParameterTuple']
@@ -52,11 +54,22 @@ PARAMETER_NAME_PREFIX_MAX_LEN = 1024
  _GLOBAL_PARAMETER_KEY = -1


- def _is_in_parallel_mode():
+ def _is_in_auto_parallel_mode():
  """Get parallel mode."""
  return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]


+ def _is_parallel_mode():
+ """ Whether is parallel mode """
+ if not _is_initialized() or context.get_context('mode') == context.PYNATIVE_MODE:
+ return False
+ if os.getenv("RUN_MODE") != "predict":
+ return False
+ if get_group_size() > 1 and _get_parallel_mode() == "stand_alone":
+ return True
+ return False
+
+
  def init_to_value(init):
  """
  Get value of initializer.
@@ -91,6 +104,15 @@ def _get_unique_parameter_key():
  return _GLOBAL_PARAMETER_KEY


+ def _gen_offload_file_path(offload_dir):
+ offload_dir = os.path.relpath(offload_dir)
+ if not os.path.exists(offload_dir):
+ os.makedirs(offload_dir, mode=0o700, exist_ok=True)
+ offload_file_path = offload_dir + "/" + str(_get_global_rank()) + "_" + str(
+ _get_unique_parameter_key()) + "_" + str(time.time()) + ".data"
+ return offload_file_path
+
+
  def _offload_if_config(data):
  """
  Offload parameter(data size > 512) to file when enable memory offload and offload parameter to disk.
@@ -111,11 +133,7 @@ def _offload_if_config(data):
  offload_file_path = data.offload_file_path()
  if offload_file_path is None or offload_file_path == "":
  offload_dir = offload_context.get("offload_path", "./offload")
- offload_dir = os.path.relpath(offload_dir)
- if not os.path.exists(offload_dir):
- os.makedirs(offload_dir)
- offload_file_path = offload_dir + "/" + str(_get_global_rank()) + "_" + str(
- _get_unique_parameter_key()) + "_" + str(time.time()) + ".data"
+ offload_file_path = _gen_offload_file_path(offload_dir)
  data.offload(offload_file_path)


@@ -191,6 +209,12 @@ class Parameter(Tensor_):
  storage_format (str): Only Ascend device target is supported. It is used to specify the format of the weight
  loaded to the device. By default, the format is not changed. The optional values are ``"FRACTAL_NZ"`` ,
  ``"NC1HWC0"`` , ``"FRACTAL_Z"`` , etc. Default: ``""`` .
+ device(str): Only Ascend device target is supported. It is used to specify the device which the parameter is
+ stored. By default, the parameter will be stored on NPU while computing. When the device is specified as
+ ``"CPU"``, the parameter will be loaded into the device when it needs to be used, and unloaded to the CPU
+ after use. It takes effext only when `memory_offload` is ``"ON"``, `jit_level` is not ``"O2"`` and
+ `memory_optimize_level` is ``O0`` in `mindspore.set_context()`. Less device memory is needed when device is
+ specified as ``"CPU"``.

  Examples:
  >>> import numpy as np
@@ -244,7 +268,7 @@ class Parameter(Tensor_):
  Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))

  def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True,
- storage_format=""):
+ storage_format="", device=None):
  self.param_info = ParamInfo()
  self.init_in_server = False
  self.name = name
@@ -263,7 +287,7 @@ class Parameter(Tensor_):
  self.requires_aggr = True
  self._cast_type = None
  self._unique = False
- self.is_in_parallel = _is_in_parallel_mode()
+ self.is_in_parallel = _is_in_auto_parallel_mode()
  self.is_in_shard = False
  self._pipeline_stage_list = []
  self.slice_num = 1
@@ -296,6 +320,10 @@ class Parameter(Tensor_):
  f" 'numpy.ndarray', 'list']. But got type {type(default_input)}.")
  self.param_info.parameter_shape = self.shape
  self.param_info.storage_format = storage_format
+ if device is not None:
+ if device != "CPU":
+ raise ValueError(f"Only 'CPU' is supported for device, but got ${device}.")
+ self._set_user_data("parameter_device", device)

  import mindspore.ops.operations.other_ops as other_ops
  self.load = other_ops.Load()
@@ -342,7 +370,8 @@ class Parameter(Tensor_):
  return (Tensor, data.asnumpy(), mstype.qint4x2)
  return (Tensor, data.asnumpy())

- not_init_data = _is_role_sched() or (_is_role_pserver() and _cache_enable()) or _is_in_parallel_mode()
+ not_init_data = _is_role_sched() or (_is_role_pserver() and _cache_enable()
+ ) or _is_in_auto_parallel_mode() or _is_parallel_mode()
  if not_init_data:
  # do not init data while in auto parallel.
  return (Tensor, None, data.dtype, get_slice_shape(data.dtype, data.shape), data.init)
@@ -368,7 +397,7 @@ class Parameter(Tensor_):

  Tutorial Examples:
  - `Parameter Server Mode
- <https://www.mindspore.cn/tutorials/experts/en/master/parallel/parameter_server_training.html>`_
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/parameter_server_training.html>`_
  """
  if not _is_ps_mode() or not (_is_role_worker() or _is_role_pserver() or _is_role_sched()):
  raise RuntimeError("Must complete following two steps before calling set_param_ps: \n"
@@ -616,6 +645,9 @@ class Parameter(Tensor_):
  shape = self.shape if self.slice_num == 1 else self.param_info.origin_shape
  dtype = self.dtype
  x.set_data(initializer(init, shape=shape, dtype=dtype))
+ device = self._get_user_data("parameter_device")
+ if device is not None:
+ x._set_user_data("parameter_device", device)
  return x

  @property
@@ -942,7 +974,7 @@ class Parameter(Tensor_):
  >>> x = Parameter(Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32)), name="param")
  >>> x.init_data()
  """
- if self.is_default_input_init and self.is_in_parallel != _is_in_parallel_mode():
+ if self.is_default_input_init and self.is_in_parallel != _is_in_auto_parallel_mode():
  raise RuntimeError("Must set or change parallel mode before any initializer Tensor created.")
  if self.init_mode is None:
  return self
@@ -1026,8 +1058,9 @@ class ParameterTuple(tuple):
  Tuple, the new Parameter tuple.

  Tutorial Examples:
- - `Cell and Parameter - Parameter Tuple
- <https://mindspore.cn/tutorials/en/master/advanced/modules/layer.html#parameter-tuple>`_
+ - `Tensor and Parameter - Parameter Tuple
+ <https://mindspore.cn/docs/en/master/model_train/model_building/tensor_and_parameter.html
+ #parameter-tuple>`_
  """
  Validator.check_str_by_regular(prefix)
  new = []
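To close, a hedged sketch of the new device argument on Parameter documented above. It assumes an Ascend environment configured per the constraints in the docstring (memory_offload "ON", memory_optimize_level "O0", and a jit level other than "O2"); outside that configuration the argument has no effect, and only "CPU" is accepted.

import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor

# Context flags follow the constraints listed in the new docstring text above.
ms.set_context(device_target="Ascend", memory_offload="ON", memory_optimize_level="O0")

# The parameter is kept host-side and loaded to the NPU only while it is in use.
weight = Parameter(Tensor(np.ones((1024, 1024), np.float32)), name="w", device="CPU")
print(weight.name, weight.shape)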