mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (308)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +3 -1
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +50 -9
  9. mindspore/_extends/parse/compile_config.py +41 -0
  10. mindspore/_extends/parse/parser.py +9 -7
  11. mindspore/_extends/parse/standard_method.py +52 -14
  12. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  13. mindspore/amp.py +24 -10
  14. mindspore/atlprov.dll +0 -0
  15. mindspore/avcodec-59.dll +0 -0
  16. mindspore/avdevice-59.dll +0 -0
  17. mindspore/avfilter-8.dll +0 -0
  18. mindspore/avformat-59.dll +0 -0
  19. mindspore/avutil-57.dll +0 -0
  20. mindspore/c1.dll +0 -0
  21. mindspore/c1xx.dll +0 -0
  22. mindspore/c2.dll +0 -0
  23. mindspore/common/__init__.py +6 -4
  24. mindspore/common/_pijit_context.py +190 -0
  25. mindspore/common/_register_for_tensor.py +2 -1
  26. mindspore/common/_tensor_overload.py +139 -0
  27. mindspore/common/api.py +102 -87
  28. mindspore/common/dump.py +5 -6
  29. mindspore/common/generator.py +1 -7
  30. mindspore/common/hook_handle.py +14 -26
  31. mindspore/common/mindir_util.py +2 -2
  32. mindspore/common/parameter.py +46 -13
  33. mindspore/common/recompute.py +39 -9
  34. mindspore/common/sparse_tensor.py +7 -3
  35. mindspore/common/tensor.py +209 -29
  36. mindspore/communication/__init__.py +1 -1
  37. mindspore/communication/_comm_helper.py +38 -3
  38. mindspore/communication/comm_func.py +310 -55
  39. mindspore/communication/management.py +14 -14
  40. mindspore/context.py +123 -22
  41. mindspore/dataset/__init__.py +1 -1
  42. mindspore/dataset/audio/__init__.py +1 -1
  43. mindspore/dataset/core/config.py +7 -0
  44. mindspore/dataset/core/validator_helpers.py +7 -0
  45. mindspore/dataset/engine/cache_client.py +1 -1
  46. mindspore/dataset/engine/datasets.py +72 -44
  47. mindspore/dataset/engine/datasets_audio.py +7 -7
  48. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  49. mindspore/dataset/engine/datasets_text.py +20 -20
  50. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  51. mindspore/dataset/engine/datasets_vision.py +33 -33
  52. mindspore/dataset/engine/iterators.py +29 -0
  53. mindspore/dataset/engine/obs/util.py +7 -0
  54. mindspore/dataset/engine/queue.py +114 -60
  55. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  56. mindspore/dataset/engine/validators.py +34 -14
  57. mindspore/dataset/text/__init__.py +1 -4
  58. mindspore/dataset/transforms/__init__.py +0 -3
  59. mindspore/dataset/utils/line_reader.py +2 -0
  60. mindspore/dataset/vision/__init__.py +1 -4
  61. mindspore/dataset/vision/utils.py +1 -1
  62. mindspore/dataset/vision/validators.py +2 -1
  63. mindspore/dnnl.dll +0 -0
  64. mindspore/dpcmi.dll +0 -0
  65. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  66. mindspore/experimental/es/embedding_service.py +883 -0
  67. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  68. mindspore/experimental/llm_boost/__init__.py +21 -0
  69. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  70. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  71. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  72. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  73. mindspore/experimental/llm_boost/register.py +129 -0
  74. mindspore/experimental/llm_boost/utils.py +31 -0
  75. mindspore/experimental/optim/adamw.py +85 -0
  76. mindspore/experimental/optim/optimizer.py +3 -0
  77. mindspore/hal/__init__.py +3 -3
  78. mindspore/hal/contiguous_tensors_handle.py +175 -0
  79. mindspore/hal/stream.py +18 -0
  80. mindspore/include/api/model_group.h +13 -1
  81. mindspore/include/api/types.h +10 -10
  82. mindspore/include/dataset/config.h +2 -2
  83. mindspore/include/dataset/constants.h +2 -2
  84. mindspore/include/dataset/execute.h +2 -2
  85. mindspore/include/dataset/vision.h +4 -0
  86. mindspore/jpeg62.dll +0 -0
  87. mindspore/log.py +1 -1
  88. mindspore/mindrecord/filewriter.py +68 -51
  89. mindspore/mindspore_backend.dll +0 -0
  90. mindspore/mindspore_common.dll +0 -0
  91. mindspore/mindspore_core.dll +0 -0
  92. mindspore/mindspore_glog.dll +0 -0
  93. mindspore/mindspore_np_dtype.dll +0 -0
  94. mindspore/mindspore_ops.dll +0 -0
  95. mindspore/mint/__init__.py +495 -46
  96. mindspore/mint/distributed/__init__.py +31 -0
  97. mindspore/mint/distributed/distributed.py +254 -0
  98. mindspore/mint/nn/__init__.py +266 -21
  99. mindspore/mint/nn/functional.py +125 -19
  100. mindspore/mint/nn/layer/__init__.py +39 -0
  101. mindspore/mint/nn/layer/activation.py +133 -0
  102. mindspore/mint/nn/layer/normalization.py +477 -0
  103. mindspore/mint/nn/layer/pooling.py +110 -0
  104. mindspore/mint/optim/adamw.py +28 -7
  105. mindspore/mint/special/__init__.py +63 -0
  106. mindspore/msobj140.dll +0 -0
  107. mindspore/mspdb140.dll +0 -0
  108. mindspore/mspdbcore.dll +0 -0
  109. mindspore/mspdbst.dll +0 -0
  110. mindspore/mspft140.dll +0 -0
  111. mindspore/msvcdis140.dll +0 -0
  112. mindspore/msvcp140_1.dll +0 -0
  113. mindspore/msvcp140_2.dll +0 -0
  114. mindspore/msvcp140_atomic_wait.dll +0 -0
  115. mindspore/msvcp140_codecvt_ids.dll +0 -0
  116. mindspore/multiprocessing/__init__.py +2 -1
  117. mindspore/nn/__init__.py +0 -1
  118. mindspore/nn/cell.py +275 -93
  119. mindspore/nn/layer/activation.py +211 -44
  120. mindspore/nn/layer/basic.py +113 -3
  121. mindspore/nn/layer/embedding.py +120 -2
  122. mindspore/nn/layer/normalization.py +101 -5
  123. mindspore/nn/layer/padding.py +34 -48
  124. mindspore/nn/layer/pooling.py +161 -7
  125. mindspore/nn/layer/transformer.py +3 -3
  126. mindspore/nn/loss/__init__.py +2 -2
  127. mindspore/nn/loss/loss.py +84 -6
  128. mindspore/nn/optim/__init__.py +2 -1
  129. mindspore/nn/optim/adadelta.py +1 -1
  130. mindspore/nn/optim/adam.py +1 -1
  131. mindspore/nn/optim/lamb.py +1 -1
  132. mindspore/nn/optim/tft_wrapper.py +127 -0
  133. mindspore/nn/wrap/cell_wrapper.py +12 -23
  134. mindspore/nn/wrap/grad_reducer.py +5 -5
  135. mindspore/nn/wrap/loss_scale.py +17 -3
  136. mindspore/numpy/__init__.py +1 -1
  137. mindspore/numpy/array_creations.py +65 -68
  138. mindspore/numpy/array_ops.py +64 -60
  139. mindspore/numpy/fft.py +610 -75
  140. mindspore/numpy/logic_ops.py +11 -10
  141. mindspore/numpy/math_ops.py +85 -84
  142. mindspore/numpy/utils_const.py +4 -4
  143. mindspore/opencv_core452.dll +0 -0
  144. mindspore/opencv_imgcodecs452.dll +0 -0
  145. mindspore/opencv_imgproc452.dll +0 -0
  146. mindspore/ops/__init__.py +6 -4
  147. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  148. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  149. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  150. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  151. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  152. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  153. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  154. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  155. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  156. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  157. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  158. mindspore/ops/composite/base.py +85 -48
  159. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  160. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  161. mindspore/ops/function/__init__.py +22 -0
  162. mindspore/ops/function/array_func.py +490 -153
  163. mindspore/ops/function/debug_func.py +113 -1
  164. mindspore/ops/function/fft_func.py +15 -2
  165. mindspore/ops/function/grad/grad_func.py +3 -2
  166. mindspore/ops/function/math_func.py +558 -207
  167. mindspore/ops/function/nn_func.py +817 -383
  168. mindspore/ops/function/other_func.py +3 -2
  169. mindspore/ops/function/random_func.py +184 -8
  170. mindspore/ops/function/reshard_func.py +13 -11
  171. mindspore/ops/function/sparse_unary_func.py +1 -1
  172. mindspore/ops/function/vmap_func.py +3 -2
  173. mindspore/ops/functional.py +24 -14
  174. mindspore/ops/op_info_register.py +3 -3
  175. mindspore/ops/operations/__init__.py +6 -1
  176. mindspore/ops/operations/_grad_ops.py +2 -76
  177. mindspore/ops/operations/_infer_ops.py +1 -1
  178. mindspore/ops/operations/_inner_ops.py +71 -94
  179. mindspore/ops/operations/array_ops.py +12 -146
  180. mindspore/ops/operations/comm_ops.py +42 -53
  181. mindspore/ops/operations/custom_ops.py +83 -19
  182. mindspore/ops/operations/debug_ops.py +42 -10
  183. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  184. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  185. mindspore/ops/operations/math_ops.py +12 -223
  186. mindspore/ops/operations/nn_ops.py +20 -114
  187. mindspore/ops/operations/other_ops.py +7 -4
  188. mindspore/ops/operations/random_ops.py +46 -1
  189. mindspore/ops/primitive.py +18 -6
  190. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  191. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  192. mindspore/ops_generate/gen_constants.py +36 -0
  193. mindspore/ops_generate/gen_ops.py +67 -52
  194. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  195. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  196. mindspore/ops_generate/op_proto.py +10 -3
  197. mindspore/ops_generate/pyboost_utils.py +14 -1
  198. mindspore/ops_generate/template.py +43 -21
  199. mindspore/parallel/__init__.py +3 -1
  200. mindspore/parallel/_auto_parallel_context.py +28 -8
  201. mindspore/parallel/_cell_wrapper.py +83 -0
  202. mindspore/parallel/_parallel_serialization.py +47 -19
  203. mindspore/parallel/_tensor.py +81 -11
  204. mindspore/parallel/_utils.py +13 -1
  205. mindspore/parallel/algo_parameter_config.py +5 -5
  206. mindspore/parallel/checkpoint_transform.py +46 -39
  207. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  208. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  209. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  210. mindspore/parallel/parameter_broadcast.py +3 -4
  211. mindspore/parallel/shard.py +162 -31
  212. mindspore/parallel/transform_safetensors.py +993 -0
  213. mindspore/pgodb140.dll +0 -0
  214. mindspore/pgort140.dll +0 -0
  215. mindspore/profiler/__init__.py +2 -1
  216. mindspore/profiler/common/constant.py +29 -0
  217. mindspore/profiler/common/registry.py +47 -0
  218. mindspore/profiler/common/util.py +28 -0
  219. mindspore/profiler/dynamic_profiler.py +694 -0
  220. mindspore/profiler/envprofiling.py +17 -19
  221. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  222. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  223. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  227. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  228. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  229. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  230. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  231. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  232. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  233. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  234. mindspore/profiler/parser/framework_parser.py +1 -391
  235. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  236. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  237. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  238. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  239. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  240. mindspore/profiler/parser/profiler_info.py +78 -6
  241. mindspore/profiler/profiler.py +153 -0
  242. mindspore/profiler/profiling.py +280 -412
  243. mindspore/rewrite/__init__.py +1 -2
  244. mindspore/rewrite/common/namespace.py +4 -4
  245. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  246. mindspore/run_check/_check_version.py +36 -103
  247. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  248. mindspore/swresample-4.dll +0 -0
  249. mindspore/swscale-6.dll +0 -0
  250. mindspore/tbbmalloc.dll +0 -0
  251. mindspore/tinyxml2.dll +0 -0
  252. mindspore/train/__init__.py +4 -3
  253. mindspore/train/_utils.py +28 -2
  254. mindspore/train/amp.py +171 -53
  255. mindspore/train/callback/__init__.py +2 -2
  256. mindspore/train/callback/_callback.py +4 -4
  257. mindspore/train/callback/_checkpoint.py +85 -22
  258. mindspore/train/callback/_cluster_monitor.py +1 -1
  259. mindspore/train/callback/_flops_collector.py +1 -0
  260. mindspore/train/callback/_loss_monitor.py +3 -3
  261. mindspore/train/callback/_on_request_exit.py +134 -31
  262. mindspore/train/callback/_summary_collector.py +5 -5
  263. mindspore/train/callback/_tft_register.py +352 -0
  264. mindspore/train/dataset_helper.py +7 -3
  265. mindspore/train/metrics/metric.py +3 -3
  266. mindspore/train/metrics/roc.py +4 -4
  267. mindspore/train/mind_ir_pb2.py +44 -39
  268. mindspore/train/model.py +134 -58
  269. mindspore/train/serialization.py +336 -112
  270. mindspore/turbojpeg.dll +0 -0
  271. mindspore/utils/__init__.py +21 -0
  272. mindspore/utils/utils.py +60 -0
  273. mindspore/vcmeta.dll +0 -0
  274. mindspore/vcruntime140.dll +0 -0
  275. mindspore/vcruntime140_1.dll +0 -0
  276. mindspore/version.py +1 -1
  277. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  278. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
  279. mindspore/include/c_api/ms/abstract.h +0 -67
  280. mindspore/include/c_api/ms/attribute.h +0 -197
  281. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  282. mindspore/include/c_api/ms/base/macros.h +0 -32
  283. mindspore/include/c_api/ms/base/status.h +0 -33
  284. mindspore/include/c_api/ms/base/types.h +0 -283
  285. mindspore/include/c_api/ms/context.h +0 -102
  286. mindspore/include/c_api/ms/graph.h +0 -160
  287. mindspore/include/c_api/ms/node.h +0 -606
  288. mindspore/include/c_api/ms/tensor.h +0 -161
  289. mindspore/include/c_api/ms/value.h +0 -84
  290. mindspore/mindspore_shared_lib.dll +0 -0
  291. mindspore/nn/extend/basic.py +0 -140
  292. mindspore/nn/extend/embedding.py +0 -143
  293. mindspore/nn/extend/layer/normalization.py +0 -109
  294. mindspore/nn/extend/pooling.py +0 -117
  295. mindspore/nn/layer/embedding_service.py +0 -531
  296. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  297. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  298. mindspore/ops/extend/__init__.py +0 -53
  299. mindspore/ops/extend/array_func.py +0 -218
  300. mindspore/ops/extend/math_func.py +0 -76
  301. mindspore/ops/extend/nn_func.py +0 -308
  302. mindspore/ops/silent_check.py +0 -162
  303. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  304. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  305. mindspore/train/callback/_mindio_ttp.py +0 -443
  306. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  307. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  308. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/common/recompute.py

@@ -23,8 +23,10 @@ from mindspore.common.tensor import Tensor
 from mindspore import ops
 from mindspore.ops.composite import GradOperation
 from mindspore.common._register_for_recompute import recompute_registry
-from mindspore.common.api import _pynative_executor
+from mindspore.common.api import _pynative_executor, _no_grad
 from mindspore.common.generator import get_rng_state, set_rng_state
+from mindspore.train.amp import amp_decorator
+from mindspore._c_expression.amp import get_curr_amp_strategy


 class _WrapCell(Cell):
@@ -34,7 +36,7 @@ class _WrapCell(Cell):
     """

     def __init__(self, function):
-        super(_WrapCell, self).__init__()
+        super(_WrapCell, self).__init__(auto_prefix=False)
         self.function = function

     def construct(self, *args, **kwargs):
@@ -56,6 +58,7 @@ class _RecomputeCell(Cell):
         self.args = []
         self.kwargs = []
         self.wrap_cell = _WrapCell(block)
+        self.wrap_cell.set_inputs()

         self.net = block
         self.internal_params = []
@@ -64,15 +67,18 @@ class _RecomputeCell(Cell):
         self._add_attr("is_cell_recompute", "True")
         self.grad = GradOperation(get_all=True, get_by_list=True, sens_param=True)
         self.init_mixed_precision_type(block)
+        self.amp_strategy = None

     def construct(self, *args, **kwargs):
-        _check_input_args_validate(self.net, args)
+        _check_input_args_validate(self.net, args, kwargs)
         self.args.append(args)
         self.kwargs.append(kwargs)
         self.save_rng_state = kwargs.pop("save_rng_state", True)
         if self.save_rng_state:
             self.cpu_rng_state = get_rng_state()
-        return self.net(*args, **kwargs)
+        self.amp_strategy = get_curr_amp_strategy()
+        with _no_grad():
+            return self.net(*args, **kwargs)

     def bprop(self, *args):
         """
@@ -86,14 +92,23 @@ class _RecomputeCell(Cell):
         self.args.pop()
         self.kwargs.pop()
         if kwargs:
-            input_args = list(input_args) + list(kwargs.values())
+            input_args_for_check = list(input_args) + list(kwargs.values())
+        else:
+            input_args_for_check = list(input_args)
         # To detach inputs to avoid erasing auto grad meta info of origin inputs.
         input_args = _detach_input(input_args)
+        kwargs = _detach_input(kwargs)
+        kwargs['sens'] = grad_input
         try:
             pre_rng_state = get_rng_state()
             set_rng_state(self.cpu_rng_state)
             _pynative_executor.set_is_run_recompute(True)
-            grads = self.grad(self.net, self.internal_params)(*input_args, grad_input)
+            if self.amp_strategy:
+                with amp_decorator(self.amp_strategy.get_amp_level(), self.amp_strategy.get_amp_dtype(),
+                                   self.amp_strategy.get_white_list(), self.amp_strategy.get_black_list()):
+                    grads = self.grad(self.net, self.internal_params)(*input_args, **kwargs)
+            else:
+                grads = self.grad(self.net, self.internal_params)(*input_args, **kwargs)
             _pynative_executor.set_is_run_recompute(False)
             set_rng_state(pre_rng_state)
         except Exception as err:
@@ -101,7 +116,7 @@ class _RecomputeCell(Cell):
             raise err
         weights = OrderedDict()
         input_grads = list(grads[0])
-        _padding_input_grads(input_args, input_grads)
+        _padding_input_grads(input_args_for_check, input_grads)
         for i, param in enumerate(self.internal_params):
             weights[param] = grads[1][i]
         return tuple(input_grads), weights
@@ -121,6 +136,7 @@ class _RecomputeCell(Cell):
             # To avoid sub cell same name
             block.__self__.check_names_and_refresh_name()
             self.internal_params = block.__self__.trainable_params()
+            self.wrap_cell.mixed_precision_type = block.__self__.get_mixed_precision_type()
             self.wrap_cell.set_mixed_precision_type(block.__self__.get_mixed_precision_type())
             self.net = self.wrap_cell
         else:
@@ -128,13 +144,14 @@ class _RecomputeCell(Cell):
                             "only support Cell object or MethodType function!")


-def _check_input_args_validate(block, args):
+def _check_input_args_validate(block, args, kwargs):
     """
     Check recompute input args validate
     :param args:
     :return:
     """
-    if not any([isinstance(arg, Tensor) for arg in args]):
+    if not (any([isinstance(arg, Tensor) for arg in args]) or \
+            any([isinstance(arg, Tensor) for arg in kwargs.values()])):
         logger.warning("None of the inputs of function are tensors, which not need use recompute!")
     for arg in args:
         if isinstance(arg, (tuple, list)):
@@ -168,6 +185,11 @@ def _padding_input_grads(args, input_grads):


 def _detach_input(input_arg):
+    """
+    Detach input
+    :param input_arg:
+    :return: detach output
+    """
     if isinstance(input_arg, Tensor):
         return ops.stop_gradient(input_arg)
     if isinstance(input_arg, (list, tuple)):
@@ -175,6 +197,14 @@ def _detach_input(input_arg):
         for arg in input_arg:
             detach_inputs.append(_detach_input(arg))
         return detach_inputs if isinstance(input_arg, list) else tuple(detach_inputs)
+    if isinstance(input_arg, dict):
+        detach_inputs = {}
+        for key, val in input_arg.items():
+            if isinstance(val, Tensor):
+                detach_inputs[key] = ops.stop_gradient(val)
+            else:
+                detach_inputs[key] = val
+        return detach_inputs
     return input_arg
mindspore/common/sparse_tensor.py

@@ -97,7 +97,8 @@ class RowTensor(RowTensorInner):
          [0, 0]]

     .. warning::
-        This is an experimental API that is subjected to change or deletion.
+        - This is an experimental API that is subjected to change or deletion.
+        - If use PyNative mode, set "export MS_PYNATIVE_CONFIG_STATIC_SHAPE=1".

     Args:
         indices (Tensor): A 1-D integer Tensor of shape :math:`(d_0)` . Default: ``None``.
@@ -226,10 +227,11 @@ class COOTensor(COOTensor_):

     Common arithmetic operations include: addition (+), subtraction (-), multiplication (*),
     and division (/). For details about operations supported by `COOTensor`, see
-    `operators <https://www.mindspore.cn/docs/en/master/note/static_graph_syntax_support.html#operators>`_.
+    `operators <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#operators>`_.

     .. warning::
         - This is an experimental API that is subject to change or deletion.
+        - If use PyNative mode, set "export MS_PYNATIVE_CONFIG_STATIC_SHAPE=1".
         - Currently, duplicate coordinates in the indices will not be coalesced.
           If the indices contain out-of-bound values, the result will be undefined.

@@ -646,6 +648,7 @@ class CSRTensor(CSRTensor_):
     [1., 2., 3., 4., 5., 6.], shape is (3, 5), then the dense representation of the sparse tensor will be:

     .. code-block::
+
         [[1., 0., 0., 2., 0.],
          [0., 3., 4., 0., 5.],
          [0., 0., 6., 0., 0.]]
@@ -668,10 +671,11 @@ class CSRTensor(CSRTensor_):

     Common arithmetic operations include: addition (+), subtraction (-), multiplication (*),
     and division (/). For details about operations supported by `CSRTensor`, see
-    `operators <https://www.mindspore.cn/docs/en/master/note/static_graph_syntax_support.html#operators>`_.
+    `operators <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#operators>`_.

     .. warning::
         - This is an experimental API that is subjected to change.
+        - If use PyNative mode, set "export MS_PYNATIVE_CONFIG_STATIC_SHAPE=1".
         - If the values given by `indptr` or `indices` are invalid, the results may be undefined. Invalid values include
           when the length of `values` or `indices` exceeds the range indicated by `indptr`, and when the columns
           indicated by `indices` are repeated on the same row.
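
All three sparse classes gain the same PyNative warning about `MS_PYNATIVE_CONFIG_STATIC_SHAPE`. A minimal sketch of honoring it for `COOTensor` (the indices, values, and shape are illustrative):

    import os
    os.environ["MS_PYNATIVE_CONFIG_STATIC_SHAPE"] = "1"  # per the new warning, before running in PyNative mode

    import mindspore as ms
    from mindspore import Tensor, COOTensor

    ms.set_context(mode=ms.PYNATIVE_MODE)
    indices = Tensor([[0, 1], [1, 2]], dtype=ms.int64)  # coordinates of the two non-zero entries
    values = Tensor([1.0, 2.0], dtype=ms.float32)
    coo = COOTensor(indices, values, (3, 4))
    print(coo.to_dense())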
mindspore/common/tensor.py

@@ -1,4 +1,4 @@
-# Copyright 2020-2022 Huawei Technologies Co., Ltd
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,6 +31,8 @@ from mindspore.common.hook_handle import _TensorHookHandle

 from mindspore.common._utils import get_slice_num
 from mindspore.common._register_for_tensor import tensor_operator_registry
+from mindspore.common._tensor_overload import (repeat_interleave_mint, add_mint, item_mint, isnan_mint, flatten_mint,
+                                               max_mint, mean_mint, min_mint, split_mint, sub_mint)
 from mindspore._c_expression import Tensor as Tensor_
 from mindspore import _checkparam as validator
 from mindspore._checkparam import check_is_number, is_stub_tensor, check_hook_fn
@@ -51,7 +53,7 @@ def _check_input_data_type(input_data):
                     np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128)
     if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \
             input_data.dtype.kind != 'U' and input_data.dtype.kind != 'S' and \
-            input_data.dtype.kind != 'T':  # Support dtype np.str_ and npy_bfloat16
+            not (input_data.dtype.kind == 'V' and input_data.dtype.char == 'E'):  # Support np.str_ and np.bfloat16
         new_line = '\n'
         for index, x in np.ndenumerate(input_data):
             if np.array(x).dtype not in valid_dtypes:
@@ -85,11 +87,11 @@ def tensor(input_data=None, dtype=None, shape=None, init=None, internal=False, c
     based on the `dtype` argument.

     Please refer to `Creating and Using Tensor
-    <https://www.mindspore.cn/docs/en/master/note/static_graph_syntax_support.html#mindspore-user-defined-data-types>`_ .
+    <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#mindspore-user-defined-data-types>`_ .

     The difference between it and the Tensor class is that it adds
     `Annotation
-    <https://www.mindspore.cn/docs/en/master/design/dynamic_graph_and_static_graph.html?#annotation-type>`_
+    <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#annotation-type>`_
     which can prevent the generation of AnyType compared to the Tensor class.

     The arguments and return values are the same as the Tensor class. Also see: :class:`mindspore.Tensor`.
@@ -143,6 +145,8 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
             Default: ``False`` .
         const_arg (bool): Whether the tensor is a constant when it is used for the argument of a network.
             Default: ``False`` .
+        device(str): This parameter is reserved and does not need to be configured.
+            Default: ``None`` .

     Outputs:
         Tensor.
@@ -205,7 +209,8 @@
     """
     delta_seed = 0

-    def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False, const_arg=False):
+    def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False, const_arg=False,
+                 device=None):
         self.init_finished = False
         if isinstance(input_data, (Tensor, Tensor_)) and dtype is not None:
             logger.info("It is suggested to use 'Tensor.astype()' to convert the dtype of a Tensor.")
@@ -264,6 +269,9 @@
             Tensor_.__init__(self, input_data)
         validator.check_value_type('const_arg', const_arg, bool, 'Tensor')

+        if device is not None and device != "CPU":
+            raise ValueError(f"Only 'CPU' is supported for device, but got {device}.")
+
         self.const_arg = const_arg
         self.virtual_flag = False
         self.init = init
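
The constructor now accepts a reserved `device` argument and rejects anything other than "CPU" (or `None`). A quick sketch of both paths:

    import numpy as np
    from mindspore import Tensor

    t = Tensor(np.ones((2, 2)), device="CPU")  # accepted; the parameter is otherwise reserved
    try:
        Tensor(np.ones((2, 2)), device="Ascend")
    except ValueError as e:
        print(e)  # Only 'CPU' is supported for device, but got Ascend.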
@@ -380,6 +388,7 @@
     def __abs__(self):
         return tensor_operator_registry.get('abs')(self)

+    @add_mint
     def __add__(self, other):
         return tensor_operator_registry.get('__add__')(self, other)

@@ -404,6 +413,7 @@
     def __iadd__(self, other):
         return self.__add__(other)

+    @sub_mint
     def __sub__(self, other):
         return tensor_operator_registry.get('__sub__')(self, other)

@@ -513,9 +523,12 @@
         return state

     def __setstate__(self, state):
-        value = state.pop("value")
+        if isinstance(state, tuple):
+            value = state
+        else:
+            value = state.pop("value")
+            self.__dict__.update(state)
         Tensor_.__setstate__(self, value)
-        self.__dict__.update(state)

     @property
     def shape(self):
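
The `__setstate__` change accepts the old tuple-shaped pickle state as well as the current dict form, so previously serialized tensors keep loading. A sketch of the round trip it guards, using the standard pickle module:

    import pickle
    import numpy as np
    from mindspore import Tensor

    t = Tensor(np.arange(4, dtype=np.float32))
    blob = pickle.dumps(t)         # current pickles carry a dict state with a "value" entry
    restored = pickle.loads(blob)  # tuple states from older pickles are now also accepted
    print(restored)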
@@ -706,8 +719,9 @@

         Examples:
             >>> from mindspore import Tensor
+            >>> from mindspore import dtype as mstype
             >>> import numpy as np
-            >>> x = Tensor(np.array([[1, 2], [3, 4]]))
+            >>> x = Tensor(np.array([[1, 2], [3, 4]]), dtype=mstype.int64)
             >>> output = x.strides
             >>> print(output)
             (16, 8)
@@ -940,6 +954,7 @@
         """
         return tensor_operator_registry.get('chunk')(self, chunks, axis)

+    @item_mint
     def item(self, index=None):
         """
         Get the item at the specified index of the tensor.
@@ -1054,7 +1069,7 @@
             self.init_data()
         return Tensor_.asnumpy(self)

-    def numpy(self):
+    def numpy(self, *, force=False):
         """
         Alias for :func:`mindspore.Tensor.asnumpy`.
         """
@@ -1295,12 +1310,48 @@
         """
         return tensor_operator_registry.get('addcmul')(self, tensor1, tensor2, value)

+    @add_mint
     def add(self, other):
         r"""
         For details, please refer to :func:`mindspore.ops.add`.
         """
         return tensor_operator_registry.get('add')(self, other)

+    def add_(self, other, *, alpha=1):
+        """
+        inplace update self by following compute:
+        self = self + other * alpha.
+
+        .. warning::
+            This is an experimental API that is subject to change or deletion.
+            The `other` tensor must be broadcastable with the `self` tensor. It may be of a different data type.
+
+        Args:
+            other (Tensor): the source tensor Add to self Tensor.
+            alpha (Number): no effect currently.
+
+        Returns:
+            Return self Tensor.
+
+        Supported Platforms:
+            ``Ascend``
+
+        Examples:
+            >>> import numpy as np
+            >>> from mindspore import Tensor
+            >>> a = Tensor(np.ones((2,3)).astype("float32"))
+            >>> b = Tensor(np.ones((2,3)).astype("float32"))
+            >>> a.add_(b)
+            >>> print(a)
+            [[2. 2. 2.]
+             [2. 2. 2.]]
+        """
+        if isinstance(other, (int, float)):
+            ret = tensor_operator_registry.get("adds_")(self, other, alpha)
+        else:
+            ret = tensor_operator_registry.get("add_")(self, other, alpha)
+        return ret
+
     def subtract(self, other, *, alpha=1):
         r"""
         For details, please refer to :func:`mindspore.ops.subtract`.
@@ -1337,6 +1388,19 @@
         """
         return tensor_operator_registry.get('addmm')(self, mat1, mat2, beta=beta, alpha=alpha)

+    def addmm_(self, mat1, mat2, *, beta=1, alpha=1):
+        r"""
+        For details, please refer to :func:`mindspore.ops.addmm`.
+
+        .. note::
+            The output results are directly updated in the Tensor.
+
+        .. warning::
+            This is an experimental API that is subject to change or deletion.
+
+        """
+        return tensor_operator_registry.get('addmm_')(self, mat1, mat2, beta=beta, alpha=alpha)
+
     def addr(self, vec1, vec2, beta=1, alpha=1):
         r"""
         For details, please refer to :func:`mindspore.ops.addr`.
@@ -1579,6 +1643,7 @@
         """
         return tensor_operator_registry.get('square')(self)

+    @sub_mint
     def sub(self, y):
         r"""
         For details, please refer to :func:`mindspore.ops.sub`.
@@ -1842,6 +1907,7 @@
         """
         return tensor_operator_registry.get('log2')(self)

+    @mean_mint
     def mean(self, axis=None, keep_dims=False):
         """
         For details, please refer to :func:`mindspore.ops.mean`.
@@ -2012,11 +2078,11 @@
         reshape_op = tensor_operator_registry.get('reshape')
         return reshape_op(self, (-1,))

-    def round(self):
+    def round(self, decimals=0):
         """
         For details, please refer to :func:`mindspore.ops.round`.
         """
-        return tensor_operator_registry.get('round')(self)
+        return tensor_operator_registry.get('round')(self, decimals=decimals)

     def roll(self, shifts, dims):
         """
@@ -2091,6 +2157,7 @@
         """
         return tensor_operator_registry.get('remainder')(self, divisor)

+    @flatten_mint
     def flatten(self, order='C', *, start_dim=0, end_dim=-1):
         r"""
         For details, please refer to :func:`mindspore.ops.flatten`.
@@ -2399,6 +2466,38 @@
             x = x.astype(origin_dtype)
         return x

+    def copy_(self, src, non_blocking=False):
+        """
+        Copies the elements from src into self tensor and returns self.
+
+        .. warning::
+            This is an experimental API that is subject to change or deletion.
+            The `src` tensor must be broadcastable with the `self` tensor. It may be of a different data type.
+
+        Args:
+            src (Tensor): the source tensor to copy from.
+            non_blocking (bool): no effect currently.
+
+        Returns:
+            Return self Tensor.
+
+        Supported Platforms:
+            ``Ascend``
+
+        Examples:
+            >>> import numpy as np
+            >>> from mindspore import Tensor
+            >>> a = Tensor(np.ones((3,3)).astype("float32"))
+            >>> b = Tensor(np.zeros((3,3)).astype("float32"))
+            >>> a.copy_(b)
+            >>> print(a)
+            [[0. 0. 0.]
+             [0. 0. 0.]
+             [0. 0. 0.]]
+        """
+        return tensor_operator_registry.get("copy_")(self, src)
+
+    @max_mint
     def max(self, axis=None, keepdims=False, *, initial=None, where=True, return_indices=False):
         """
         Return the maximum of a tensor or maximum along an axis.
@@ -2467,6 +2566,7 @@
             return values
         return values, indices

+    @min_mint
     def min(self, axis=None, keepdims=False, *, initial=None, where=True, return_indices=False):
         """
         Return the minimum of a tensor or minimum along an axis.
@@ -2763,7 +2863,7 @@
             opt_shard_group(str): Optimizer shard group which is used in auto or semi auto parallel mode
                 to get one shard of a parameter's slice. For more information about optimizer parallel, please refer to:
                 `Optimizer Parallel
-                <https://www.mindspore.cn/tutorials/experts/en/master/parallel/optimizer_parallel.html>`_.
+                <https://www.mindspore.cn/docs/en/master/model_train/parallel/optimizer_parallel.html>`_.
                 Default: ``None``.

         Returns:
@@ -2995,16 +3095,7 @@
         >>> print(x.trace())
         3.0
         """
-        if offset == 0 and axis1 == 0 and axis2 == 1 and dtype is None:
-            return tensor_operator_registry.get('trace')(self)
-        d = self.diagonal(offset, axis1=axis1, axis2=axis2)
-        shape = d.shape
-        if dtype is None:
-            dtype = d.dtype
-        if shape[-1] == 0:
-            return tensor_operator_registry.get('fill')(dtype, shape[:-1], 0)
-        res = tensor_operator_registry.get('reduce_sum')(d.astype(mstype.float32), -1)
-        return res.astype(dtype)
+        return tensor_operator_registry.get('tracev2')(self, offset, axis1, axis2, dtype)

     def take(self, indices, axis=None, mode='clip'):
         """
@@ -3164,6 +3255,7 @@
             sorter (Union[int, list, tuple, Tensor]): optional tensor of
                 integer indices that sort the tensor into ascending order on the innermost dimension
                 and the type must be int64. They are typically the result of argsort. Default: ``None`` .
+                CPU and GPU can only use default values

         Returns:
             Tensor, array of insertion points with the same shape as `v`.
@@ -3217,10 +3309,10 @@

     def uniform(self, from_=0., to=1., generator=None):
         r"""
-        Generates random numbers in the half-open interval [from_, to).
+        Generates random numbers in the half-open interval [from\_, to).

         Args:
-            from_ (number): The lower bound of the interval.
+            from\_ (number): The lower bound of the interval.
             to (number): The upper bound of the interval.
             generator (Generator, optional): The random seed. Default: None.

@@ -3506,6 +3598,7 @@
             repeated_subs.append(tensor_operator_registry.get('repeat_elements')(sub, rep, axis))
         return tensor_operator_registry.get('concatenate')(repeated_subs, axis)

+    @repeat_interleave_mint
     def repeat_interleave(self, repeats, dim=None):
         """
         For details, please refer to :func:`mindspore.ops.repeat_interleave`.
@@ -3740,6 +3833,7 @@
         """
         return tensor_operator_registry.get("xdivy")(self, y)

+    @split_mint
     def split(self, split_size_or_sections, axis=0):
         """
         For details, please refer to :func:`mindspore.ops.split`.
@@ -4039,6 +4133,27 @@
         """
         return tensor_operator_registry.get('int')(self, mstype.int32)

+    def byte(self):
+        r"""
+        Converts input tensor dtype to `uint8`.
+
+        Returns:
+            Tensor, converted to the `uint8` dtype.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+
+        Examples:
+            >>> import numpy as np
+            >>> import mindspore
+            >>> from mindspore import Tensor
+            >>> input_x = Tensor(np.ones([2,2]), mindspore.float32)
+            >>> output = input_x.byte()
+            >>> print(output.dtype)
+            uint8
+        """
+        return tensor_operator_registry.get('byte')(self, mstype.uint8)
+
     def long(self):
         r"""
         Converts input tensor dtype to `int64`. If the value in tensor is float or half, the decimal will be discarded.
@@ -4249,6 +4364,7 @@
         """
         return tensor_operator_registry.get('isinf')(self)

+    @isnan_mint
     def isnan(self):
         r"""
         For details, please refer to :func:`mindspore.ops.isnan`.
@@ -4425,7 +4541,7 @@
         """
         return tensor_operator_registry.get('mul')(self, value)

-    def nan_to_num(self, nan=0.0, posinf=None, neginf=None):
+    def nan_to_num(self, nan=None, posinf=None, neginf=None):
         """
         For details, please refer to :func:`mindspore.ops.nan_to_num`.
         """
@@ -4482,6 +4598,31 @@
         """
         return tensor_operator_registry.get('zeros')(size, dtype)

+    def zero_(self):
+        r"""
+        Return a tensor filled with zeros.
+
+        .. warning::
+            This is an experimental API that is subject to change or deletion.
+
+        Returns:
+            Return a tensor. Fill self tensor with zeros.
+
+        Supported Platforms:
+            ``Ascend``
+
+        Examples:
+            >>> import numpy as np
+            >>> import mindspore
+            >>> from mindspore import Tensor
+            >>> x = Tensor(np.ones((2, 2)), mindspore.float32)
+            >>> output = x.zero_()
+            >>> print(output)
+            [[0. 0.]
+             [0. 0.]]
+        """
+        return tensor_operator_registry.get('zero_')(self)
+
     def new_ones(self, size, dtype=None):
         r"""
         Return a tensor of `size` filled with ones.
@@ -4758,7 +4899,7 @@
         mode = context.get_context("mode")
         if mode != context.PYNATIVE_MODE:
             raise ValueError(f"The method of 'move_to' only supported in pynative mode, but got: {mode}.")
-        return Tensor(Tensor_.move_to(self, to, blocking))
+        return Tensor(Tensor_.move_to(self, to, blocking), device="CPU" if to == "CPU" else None)


     def _offload(self):
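
`move_to` now constructs the returned tensor with the new `device` field when the target is "CPU". Illustrative call (PyNative only, per the mode check above):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor

    ms.set_context(mode=ms.PYNATIVE_MODE)
    t = Tensor(np.ones((2, 2)), ms.float32)
    cpu_t = t.move_to("CPU")  # returned Tensor is built with device="CPU"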
@@ -4805,6 +4946,44 @@ def _vm_compare(*args):
     return Tensor(np.array(fn(y)))


+def _check_sequence_shape(input_data):
+    """Check the shape of tensor input with type of sequence."""
+    max_dims_reached = False
+    max_ndim = 64  # corresponding to NPY_MAXDIMS
+    out_shape = [0] * max_ndim
+
+    def check_shape_recursive(input_data, curr_ndim):
+        nonlocal max_dims_reached, max_ndim, out_shape
+        if curr_ndim > max_ndim:
+            return False
+        if not isinstance(input_data, (tuple, list)):
+            if max_dims_reached and curr_ndim != max_ndim:
+                max_ndim = curr_ndim
+                return False
+            max_dims_reached = True
+            max_ndim = curr_ndim
+            return True
+        if not max_dims_reached:
+            out_shape[curr_ndim] = len(input_data)
+        else:
+            if out_shape[curr_ndim] != len(input_data):
+                max_ndim = curr_ndim
+                return False
+        if not input_data:
+            # process empty list
+            if not check_shape_recursive(None, curr_ndim + 1):
+                return False
+        for data in input_data:
+            if not check_shape_recursive(data, curr_ndim + 1):
+                return False
+        return True
+
+    if not check_shape_recursive(input_data, 0):
+        raise ValueError(f"When initializing a tensor with a sequence, the sequence has an inhomogeneous shape "
+                         f"after {max_ndim} dimensions. The detected shape was {tuple(out_shape[:max_ndim])} "
+                         f"+ inhomogeneous part.")
+
+
 def _check_tensor_input(input_data=None, dtype=None, shape=None, init=None):
     """Check the tensor input."""
     if input_data is not None and shape is not None:
@@ -4817,9 +4996,10 @@ def _check_tensor_input(input_data=None, dtype=None, shape=None, init=None):
     if input_data is not None:
         if isinstance(input_data, np.ndarray) and input_data.ndim >= 1 and input_data.size == 0:
             raise ValueError("input_data can not contain zero dimension.")
-        if isinstance(input_data, (tuple, list)) and np.array(input_data).ndim >= 1 \
-                and np.array(input_data).size == 0:
-            raise ValueError("input_data can not contain zero dimension.")
+        if isinstance(input_data, (tuple, list)):
+            _check_sequence_shape(input_data)
+            if np.array(input_data).ndim >= 1 and np.array(input_data).size == 0:
+                raise ValueError("input_data can not contain zero dimension.")

     if shape is not None and not (hasattr(init, "__enable_zero_dim__") and init.__enable_zero_dim__) and 0 in shape:
         raise ValueError("Shape can not contain zero value.")
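
Taken together, these two hunks make 2.4.0 reject ragged (inhomogeneous) nested sequences at construction time instead of deferring to NumPy's coercion. A sketch of the new failure mode, with the error text as shown in the diff:

    from mindspore import Tensor

    Tensor([[1, 2], [3, 4]])   # homogeneous 2x2 sequence: accepted as before
    try:
        Tensor([[1, 2], [3]])  # ragged rows now fail _check_sequence_shape
    except ValueError as e:
        print(e)  # "...the sequence has an inhomogeneous shape after ... dimensions..."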
mindspore/communication/__init__.py

@@ -20,7 +20,7 @@ Note that the APIs in the following list need to preset communication environmen
 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
 without any third-party or configuration file dependencies.
 Please see the `msrun start up
-<https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+<https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
 for more details.
 """