mindspore 2.3.0-cp310-cp310-win_amd64.whl → 2.4.1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


Files changed (275)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/common/__init__.py +6 -4
  13. mindspore/common/_pijit_context.py +190 -0
  14. mindspore/common/_register_for_tensor.py +2 -1
  15. mindspore/common/_tensor_overload.py +139 -0
  16. mindspore/common/api.py +102 -87
  17. mindspore/common/dump.py +5 -6
  18. mindspore/common/generator.py +1 -7
  19. mindspore/common/hook_handle.py +14 -26
  20. mindspore/common/initializer.py +51 -15
  21. mindspore/common/mindir_util.py +2 -2
  22. mindspore/common/parameter.py +62 -15
  23. mindspore/common/recompute.py +39 -9
  24. mindspore/common/sparse_tensor.py +7 -3
  25. mindspore/common/tensor.py +183 -37
  26. mindspore/communication/__init__.py +1 -1
  27. mindspore/communication/_comm_helper.py +38 -3
  28. mindspore/communication/comm_func.py +315 -60
  29. mindspore/communication/management.py +14 -14
  30. mindspore/context.py +132 -22
  31. mindspore/dataset/__init__.py +1 -1
  32. mindspore/dataset/audio/__init__.py +1 -1
  33. mindspore/dataset/core/config.py +7 -0
  34. mindspore/dataset/core/validator_helpers.py +7 -0
  35. mindspore/dataset/engine/cache_client.py +1 -1
  36. mindspore/dataset/engine/datasets.py +72 -44
  37. mindspore/dataset/engine/datasets_audio.py +7 -7
  38. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  39. mindspore/dataset/engine/datasets_text.py +20 -20
  40. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  41. mindspore/dataset/engine/datasets_vision.py +33 -33
  42. mindspore/dataset/engine/iterators.py +29 -0
  43. mindspore/dataset/engine/obs/util.py +7 -0
  44. mindspore/dataset/engine/queue.py +114 -60
  45. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  46. mindspore/dataset/engine/validators.py +34 -14
  47. mindspore/dataset/text/__init__.py +1 -4
  48. mindspore/dataset/transforms/__init__.py +0 -3
  49. mindspore/dataset/utils/line_reader.py +2 -0
  50. mindspore/dataset/vision/__init__.py +1 -4
  51. mindspore/dataset/vision/utils.py +1 -1
  52. mindspore/dataset/vision/validators.py +2 -1
  53. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  54. mindspore/experimental/es/embedding_service.py +883 -0
  55. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  56. mindspore/experimental/llm_boost/__init__.py +21 -0
  57. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  58. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  59. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  60. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  61. mindspore/experimental/llm_boost/register.py +129 -0
  62. mindspore/experimental/llm_boost/utils.py +31 -0
  63. mindspore/experimental/optim/adamw.py +85 -0
  64. mindspore/experimental/optim/optimizer.py +3 -0
  65. mindspore/hal/__init__.py +3 -3
  66. mindspore/hal/contiguous_tensors_handle.py +175 -0
  67. mindspore/hal/stream.py +18 -0
  68. mindspore/include/api/model_group.h +13 -1
  69. mindspore/include/api/types.h +10 -10
  70. mindspore/include/dataset/config.h +2 -2
  71. mindspore/include/dataset/constants.h +2 -2
  72. mindspore/include/dataset/execute.h +2 -2
  73. mindspore/include/dataset/vision.h +4 -0
  74. mindspore/log.py +1 -1
  75. mindspore/mindrecord/filewriter.py +68 -51
  76. mindspore/mindspore_backend.dll +0 -0
  77. mindspore/mindspore_common.dll +0 -0
  78. mindspore/mindspore_core.dll +0 -0
  79. mindspore/mindspore_np_dtype.dll +0 -0
  80. mindspore/mindspore_ops.dll +0 -0
  81. mindspore/mint/__init__.py +983 -46
  82. mindspore/mint/distributed/__init__.py +31 -0
  83. mindspore/mint/distributed/distributed.py +254 -0
  84. mindspore/mint/nn/__init__.py +268 -23
  85. mindspore/mint/nn/functional.py +125 -19
  86. mindspore/mint/nn/layer/__init__.py +39 -0
  87. mindspore/mint/nn/layer/activation.py +133 -0
  88. mindspore/mint/nn/layer/normalization.py +477 -0
  89. mindspore/mint/nn/layer/pooling.py +110 -0
  90. mindspore/mint/optim/adamw.py +26 -13
  91. mindspore/mint/special/__init__.py +63 -0
  92. mindspore/multiprocessing/__init__.py +2 -1
  93. mindspore/nn/__init__.py +0 -1
  94. mindspore/nn/cell.py +276 -96
  95. mindspore/nn/layer/activation.py +211 -44
  96. mindspore/nn/layer/basic.py +137 -10
  97. mindspore/nn/layer/embedding.py +137 -2
  98. mindspore/nn/layer/normalization.py +101 -5
  99. mindspore/nn/layer/padding.py +34 -48
  100. mindspore/nn/layer/pooling.py +161 -7
  101. mindspore/nn/layer/transformer.py +3 -3
  102. mindspore/nn/loss/__init__.py +2 -2
  103. mindspore/nn/loss/loss.py +84 -6
  104. mindspore/nn/optim/__init__.py +2 -1
  105. mindspore/nn/optim/adadelta.py +1 -1
  106. mindspore/nn/optim/adam.py +1 -1
  107. mindspore/nn/optim/lamb.py +1 -1
  108. mindspore/nn/optim/tft_wrapper.py +124 -0
  109. mindspore/nn/wrap/cell_wrapper.py +12 -23
  110. mindspore/nn/wrap/grad_reducer.py +5 -5
  111. mindspore/nn/wrap/loss_scale.py +17 -3
  112. mindspore/numpy/__init__.py +1 -1
  113. mindspore/numpy/array_creations.py +65 -68
  114. mindspore/numpy/array_ops.py +64 -60
  115. mindspore/numpy/fft.py +610 -75
  116. mindspore/numpy/logic_ops.py +11 -10
  117. mindspore/numpy/math_ops.py +85 -84
  118. mindspore/numpy/utils_const.py +4 -4
  119. mindspore/opencv_core452.dll +0 -0
  120. mindspore/opencv_imgcodecs452.dll +0 -0
  121. mindspore/opencv_imgproc452.dll +0 -0
  122. mindspore/ops/__init__.py +6 -4
  123. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  124. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  125. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  126. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  127. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  128. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  129. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  130. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  131. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  132. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  133. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  134. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  135. mindspore/ops/composite/base.py +85 -48
  136. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  137. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  138. mindspore/ops/function/__init__.py +22 -0
  139. mindspore/ops/function/array_func.py +492 -153
  140. mindspore/ops/function/debug_func.py +113 -1
  141. mindspore/ops/function/fft_func.py +15 -2
  142. mindspore/ops/function/grad/grad_func.py +3 -2
  143. mindspore/ops/function/math_func.py +564 -207
  144. mindspore/ops/function/nn_func.py +817 -383
  145. mindspore/ops/function/other_func.py +3 -2
  146. mindspore/ops/function/random_func.py +402 -12
  147. mindspore/ops/function/reshard_func.py +13 -11
  148. mindspore/ops/function/sparse_unary_func.py +1 -1
  149. mindspore/ops/function/vmap_func.py +3 -2
  150. mindspore/ops/functional.py +24 -14
  151. mindspore/ops/op_info_register.py +3 -3
  152. mindspore/ops/operations/__init__.py +7 -2
  153. mindspore/ops/operations/_grad_ops.py +2 -76
  154. mindspore/ops/operations/_infer_ops.py +1 -1
  155. mindspore/ops/operations/_inner_ops.py +71 -94
  156. mindspore/ops/operations/array_ops.py +14 -146
  157. mindspore/ops/operations/comm_ops.py +63 -53
  158. mindspore/ops/operations/custom_ops.py +83 -19
  159. mindspore/ops/operations/debug_ops.py +42 -10
  160. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  161. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  162. mindspore/ops/operations/math_ops.py +12 -223
  163. mindspore/ops/operations/nn_ops.py +20 -114
  164. mindspore/ops/operations/other_ops.py +7 -4
  165. mindspore/ops/operations/random_ops.py +46 -1
  166. mindspore/ops/primitive.py +18 -6
  167. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  168. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  169. mindspore/ops_generate/gen_constants.py +36 -0
  170. mindspore/ops_generate/gen_ops.py +67 -52
  171. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  172. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  173. mindspore/ops_generate/op_proto.py +10 -3
  174. mindspore/ops_generate/pyboost_utils.py +14 -1
  175. mindspore/ops_generate/template.py +43 -21
  176. mindspore/parallel/__init__.py +3 -1
  177. mindspore/parallel/_auto_parallel_context.py +31 -9
  178. mindspore/parallel/_cell_wrapper.py +85 -0
  179. mindspore/parallel/_parallel_serialization.py +47 -19
  180. mindspore/parallel/_tensor.py +127 -13
  181. mindspore/parallel/_utils.py +53 -22
  182. mindspore/parallel/algo_parameter_config.py +5 -5
  183. mindspore/parallel/checkpoint_transform.py +46 -39
  184. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  185. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  186. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  187. mindspore/parallel/parameter_broadcast.py +3 -4
  188. mindspore/parallel/shard.py +162 -31
  189. mindspore/parallel/transform_safetensors.py +1146 -0
  190. mindspore/profiler/__init__.py +2 -1
  191. mindspore/profiler/common/constant.py +29 -0
  192. mindspore/profiler/common/registry.py +47 -0
  193. mindspore/profiler/common/util.py +28 -0
  194. mindspore/profiler/dynamic_profiler.py +694 -0
  195. mindspore/profiler/envprofiling.py +17 -19
  196. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  197. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  198. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  199. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  200. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  201. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  202. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  203. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  205. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  206. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  207. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  208. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  209. mindspore/profiler/parser/framework_parser.py +1 -391
  210. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  211. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  212. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  213. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  214. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  215. mindspore/profiler/parser/profiler_info.py +78 -6
  216. mindspore/profiler/profiler.py +153 -0
  217. mindspore/profiler/profiling.py +285 -413
  218. mindspore/rewrite/__init__.py +1 -2
  219. mindspore/rewrite/common/namespace.py +4 -4
  220. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  221. mindspore/run_check/_check_version.py +39 -104
  222. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  223. mindspore/train/__init__.py +4 -3
  224. mindspore/train/_utils.py +105 -19
  225. mindspore/train/amp.py +171 -53
  226. mindspore/train/callback/__init__.py +2 -2
  227. mindspore/train/callback/_callback.py +4 -4
  228. mindspore/train/callback/_checkpoint.py +97 -31
  229. mindspore/train/callback/_cluster_monitor.py +1 -1
  230. mindspore/train/callback/_flops_collector.py +1 -0
  231. mindspore/train/callback/_loss_monitor.py +3 -3
  232. mindspore/train/callback/_on_request_exit.py +145 -31
  233. mindspore/train/callback/_summary_collector.py +5 -5
  234. mindspore/train/callback/_tft_register.py +375 -0
  235. mindspore/train/dataset_helper.py +15 -3
  236. mindspore/train/metrics/metric.py +3 -3
  237. mindspore/train/metrics/roc.py +4 -4
  238. mindspore/train/mind_ir_pb2.py +44 -39
  239. mindspore/train/model.py +154 -58
  240. mindspore/train/serialization.py +342 -128
  241. mindspore/utils/__init__.py +21 -0
  242. mindspore/utils/utils.py +60 -0
  243. mindspore/version.py +1 -1
  244. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  245. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
  246. mindspore/include/c_api/ms/abstract.h +0 -67
  247. mindspore/include/c_api/ms/attribute.h +0 -197
  248. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  249. mindspore/include/c_api/ms/base/macros.h +0 -32
  250. mindspore/include/c_api/ms/base/status.h +0 -33
  251. mindspore/include/c_api/ms/base/types.h +0 -283
  252. mindspore/include/c_api/ms/context.h +0 -102
  253. mindspore/include/c_api/ms/graph.h +0 -160
  254. mindspore/include/c_api/ms/node.h +0 -606
  255. mindspore/include/c_api/ms/tensor.h +0 -161
  256. mindspore/include/c_api/ms/value.h +0 -84
  257. mindspore/mindspore_shared_lib.dll +0 -0
  258. mindspore/nn/extend/basic.py +0 -140
  259. mindspore/nn/extend/embedding.py +0 -143
  260. mindspore/nn/extend/layer/normalization.py +0 -109
  261. mindspore/nn/extend/pooling.py +0 -117
  262. mindspore/nn/layer/embedding_service.py +0 -531
  263. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  264. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  265. mindspore/ops/extend/__init__.py +0 -53
  266. mindspore/ops/extend/array_func.py +0 -218
  267. mindspore/ops/extend/math_func.py +0 -76
  268. mindspore/ops/extend/nn_func.py +0 -308
  269. mindspore/ops/silent_check.py +0 -162
  270. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  271. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  272. mindspore/train/callback/_mindio_ttp.py +0 -443
  273. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
  274. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  275. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
@@ -39,11 +39,12 @@ from ..auto_generate import (ExpandDims, Reshape, TensorShape, Transpose, Gather
                              OnesLike, ZerosLike, Argmax, ArgMaxExt,
                              ReverseV2, Diag, Eye, ScatterNd, ResizeNearestNeighborV2,
                              GatherNd, GatherD, Range, MaskedFill, RightShift, NonZero,
-                             ResizeNearestNeighbor, Identity, Split, CumSum, CumProd,
+                             ResizeNearestNeighbor, Identity, Split, CumSum, CumProd, MaskedSelect,
                              Cummax, Cummin, Argmin, Concat, UnsortedSegmentSum, ScalarToTensor,
                              Triu, BroadcastTo, StridedSlice, Select, TopkExt, SearchSorted)
 from .manually_defined import Rank, Shape, Tile, Cast, Ones, Zeros
 from ..auto_generate import ArgMaxWithValue, ArgMinWithValue
+from ..auto_generate import TensorScatterElements as TensorScatterElementsExt
 
 class _ScatterOp(PrimitiveWithInfer):
     """
@@ -769,41 +770,15 @@ class Padding(Primitive):
 
 class UniqueWithPad(Primitive):
     """
-    Returns unique elements and relative indexes in 1-D tensor, filled with padding num.
-
-    The basic function is the same as the Unique operator, but the UniqueWithPad operator adds a Pad function.
-    The returned tuple(`y`, `idx`) after the input Tensor `x` is processed by the unique operator,
-    in which the shapes of `y` and `idx` are mostly not equal. Therefore, in order to solve the above situation,
-    the UniqueWithPad operator will fill the `y` Tensor with the `pad_num` specified by the user
-    to make it have the same shape as the Tensor `idx`.
-
-    Refer to :func:`mindspore.ops.unique_with_pad` for more details.
-
-    Inputs:
-        - **x** (Tensor) - The tensor need to be unique. Must be 1-D vector with types: int32, int64.
-        - **pad_num** (int) - Pad num. The data type is an int.
-
-    Outputs:
-        tuple(Tensor), tuple of 2 tensors, `y` and `idx`.
-
-        - y (Tensor) - The unique elements filled with pad_num, the shape and data type same as `x`.
-        - idx (Tensor) - The index of each value of `x` in the unique output `y`, the shape and data type same as `x`.
+    'ops.UniqueWithPad' is deprecated from version 2.4 and will be removed in a future version.
+    Please use the :func:`mindspore.ops.unique` combined with :func:`mindspore.ops.pad` to realize
+    the same function.
 
     Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import numpy as np
-        >>> from mindspore import Tensor, ops
-        >>> x = Tensor(np.array([1, 1, 2, 2, 3, 3, 4, 5]), mindspore.int32)
-        >>> pad_num = 8
-        >>> output = ops.UniqueWithPad()(x, pad_num)
-        >>> print(output)
-        (Tensor(shape=[8], dtype=Int32, value= [1, 2, 3, 4, 5, 8, 8, 8]),
-        Tensor(shape=[8], dtype=Int32, value= [0, 0, 1, 1, 2, 2, 3, 4]))
+        Deprecated
     """
 
+    @deprecated("2.4", "ops.unique and ops.pad", False)
     @prim_attr_register
     def __init__(self):
         """init UniqueWithPad"""
@@ -819,7 +794,7 @@ class Size(Primitive):
 
     Inputs:
         - **input_x** (Tensor) - Input parameters, the shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is
-          `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.
+          `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
 
     Outputs:
         int. A scalar representing the elements' size of `input_x`, tensor is the number of elements
@@ -2112,60 +2087,6 @@ class Rint(Primitive):
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
 
 
-class StridedSliceV2(Primitive):
-    r"""
-    StridedSliceV2 will be deprecated by StridedSlice in the future.
-    Extracts a strided slice of a tensor.
-    Refer to class StridedSlice for more details.
-
-    Args:
-        begin_mask (int): Starting index of the slice. Default: ``0`` .
-        end_mask (int): Ending index of the slice. Default: ``0`` .
-        ellipsis_mask (int): An int mask. Default: ``0`` .
-        new_axis_mask (int): An int mask. Default: ``0`` .
-        shrink_axis_mask (int): An int mask. Default: ``0`` .
-
-    Inputs:
-        - **input_x** (Tensor) - The input Tensor.
-        - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
-          constant value is allowed.
-        - **end** (tuple[int]) - A tuple or which represents the maximum location where to end.
-          Only constant value is allowed.
-        - **strides** (tuple[int]) - A tuple which represents the stride is continuously added
-          before reaching the maximum location. Only constant value is allowed.
-
-    Outputs:
-        Tensor, The output is explained by following example.
-
-    Raises:
-        TypeError: If `begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask` or `shrink_axis_mask` is not an int.
-        TypeError: If `begin`, `end` or `strides` is not a tuple.
-        ValueError: If `begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask` or `shrink_axis_mask` is less than 0.
-
-    Supported Platforms:
-        ``Ascend`` ``CPU``
-
-    Examples:
-        >>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
-        ...                   [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
-        >>> strided_slice_v2 = ops.StridedSliceV2()
-        >>> output = strided_slice_v2(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
-        >>> print(output)
-        [[[3.]]
-         [[5.]]]
-    """
-
-    @prim_attr_register
-    def __init__(self,
-                 begin_mask=0,
-                 end_mask=0,
-                 ellipsis_mask=0,
-                 new_axis_mask=0,
-                 shrink_axis_mask=0):
-        """Initialize StridedSliceV2"""
-        self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
-
-
 class DiagPart(PrimitiveWithCheck):
     r"""
 
@@ -4356,53 +4277,6 @@ class MaskedScatter(Primitive):
         self.init_prim_io_names(inputs=['x', 'mask', 'updates'], outputs=['y'])
 
 
-class MaskedSelect(PrimitiveWithCheck):
-    """
-    Returns a new 1-D Tensor which indexes the `x` tensor according to the boolean `mask`.
-    The shapes of the `mask` tensor and the `x` tensor don't need to match, but they must be broadcastable.
-
-    Inputs:
-        - **x** (Tensor) - Input Tensor of any dimension.
-        - **mask** (Tensor[bool]) - Boolean mask Tensor, has the same shape as `x`.
-
-    Outputs:
-        A 1-D Tensor, with the same type as x.
-
-    Raises:
-        TypeError: If `x` or `mask` is not a Tensor.
-        TypeError: If dtype of `mask` is not bool.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import numpy as np
-        >>> from mindspore import Tensor, ops
-        >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int32)
-        >>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
-        >>> output = ops.MaskedSelect()(x, mask)
-        >>> print(output)
-        [1 3]
-        >>> x = Tensor(2.1, mindspore.float32)
-        >>> mask = Tensor(True, mindspore.bool_)
-        >>> output = ops.MaskedSelect()(x, mask)
-        >>> print(output)
-        [2.1]
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        self.init_prim_io_names(inputs=['x', 'mask'], outputs=['output'])
-
-    def check_shape(self, x_shape, mask_shape):
-        get_broadcast_shape(x_shape, mask_shape, self.name, arg_name1="x", arg_name2="mask")
-
-    def check_dtype(self, x_dtype, mask_dtype):
-        validator.check_tensor_dtype_valid('mask', mask_dtype, [mstype.bool_], self.name)
-        validator.check_tensor_dtype_valid('x', x_dtype, (mstype.bool_,) + mstype.number_type, self.name)
-
-
 class _TensorScatterOp(PrimitiveWithInfer):
     """
     Defines TensorScatter Base Operators
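The MaskedSelect class is removed from this file because the primitive is now imported from `..auto_generate` (see the import change in the first hunk), so `ops.MaskedSelect` remains available. A quick check reusing the example from the deleted docstring (expected output taken from that docstring, not re-verified against 2.4.1):

    import mindspore
    import numpy as np
    from mindspore import Tensor, ops

    x = Tensor(np.array([1, 2, 3, 4]), mindspore.int32)
    mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
    print(ops.MaskedSelect()(x, mask))  # expected: [1 3]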
@@ -4962,7 +4836,7 @@ class SplitV(Primitive):
         self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
 
 
-class TensorScatterElements(Primitive):
+class TensorScatterElements(TensorScatterElementsExt):
     """
     Write all elements in `updates` to the index specified by `indices` in `input_x` according to the reduction
     operation specified by `reduction`.
@@ -4977,6 +4851,9 @@ class TensorScatterElements(Primitive):
     .. warning::
         This is an experimental API that is subject to change or deletion.
 
+    Note:
+        The backward is supported only for the case `updates.shape == indices.shape`.
+
     Args:
         axis (int, optional): Specify which axis to do scatter operation. Default: ``0`` .
         reduction (str, optional): Which reduction operation to scatter, default is ``"none"`` . Other option: "add".
@@ -4986,7 +4863,7 @@ class TensorScatterElements(Primitive):
         - **indices** (Tensor) - The index of `input_x` to do scatter operation whose data type must be int32 or
           int64. It has the same rank as `data`. And accepted range is [-s, s) where s is the size along axis.
         - **updates** (Tensor) - The tensor doing the scatter operation with `data`,
-          it has the same type as `data` and the same shape as `indices`.
+          it has the same type as `data`.
 
     Outputs:
         Tensor, has the same shape and type as `data`.
@@ -5021,16 +4898,7 @@ class TensorScatterElements(Primitive):
 
     @prim_attr_register
     def __init__(self, axis=0, reduction="none"):
-        """Initialize TensorScatterElements"""
-        validator.check_value_type("axis", axis, [int], self.name)
-        validator.check_value_type("reduction", reduction, [str], self.name)
-        validator.check_string(reduction, ["none", "add"], "reduction", self.name)
-        self.init_prim_io_names(inputs=['data', 'indices', 'updates'], outputs=['y'])
-        target = context.get_context("device_target")
-        if reduction != 'none' and target.lower() == "ascend":
-            raise ValueError(f"For '{self.name}', "
-                             f"Currently Ascend device_target only support `reduction`='none', "
-                             f"but got {reduction}")
+        super().__init__(axis, reduce=reduction)
 
 
 class ExtractVolumePatches(Primitive):
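With `__init__` now deferring to the auto-generated `TensorScatterElementsExt`, the old `reduction` keyword is forwarded as the new `reduce` argument, so existing call sites keep the same constructor signature. A small illustrative sketch (shapes chosen so that `updates.shape == indices.shape`, matching the backward-support note added above; the output comment is inferred from the documented semantics, not re-run against 2.4.1):

    import mindspore
    import numpy as np
    from mindspore import Tensor, ops

    data = Tensor(np.zeros((3, 3)), mindspore.float32)
    indices = Tensor(np.array([[0, 1, 2]]), mindspore.int32)
    updates = Tensor(np.array([[1.0, 2.0, 3.0]]), mindspore.float32)

    scatter = ops.TensorScatterElements(axis=0, reduction="add")
    # With axis=0, updates[0, j] is added at data[indices[0, j], j],
    # i.e. onto the diagonal here: [[1, 0, 0], [0, 2, 0], [0, 0, 3]].
    print(scatter(data, indices, updates))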
@@ -54,7 +54,7 @@ class ReduceOp:
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
     without any third-party or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with multiple devices.
@@ -141,7 +141,7 @@ class AllReduce(Primitive):
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
     without any third-party or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 2 devices.
@@ -178,14 +178,15 @@ class AllReduce(Primitive):
     @prim_attr_register
     def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
         """Initialize AllReduce."""
+        self.group = _get_group(group)
         if not isinstance(op, type(ReduceOp.SUM)):
             raise TypeError(f"For '{self.name}', the 'op' must be str, but got {type(op).__name__}.")
-        if not isinstance(_get_group(group), str):
+        if not isinstance(self.group, str):
             raise TypeError(f"For '{self.name}', the 'group' must be str, "
-                            f"but got {type(_get_group(group)).__name__}.")
-        check_hcom_group_valid(group, prim_name=self.name)
+                            f"but got {type(self.group).__name__}.")
+        check_hcom_group_valid(self.group, prim_name=self.name)
         self.op = op
-        self.add_prim_attr('group', _get_group(group))
+        self.add_prim_attr('group', self.group)
         self.add_prim_attr('fusion', 0)
         self.add_prim_attr('index', 0)
         self.add_prim_attr('no_eliminate', True)
@@ -230,7 +231,7 @@ class Reduce(PrimitiveWithInfer):
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method without any third-party
     or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 4 devices.
@@ -314,7 +315,7 @@ class AllGather(PrimitiveWithInfer):
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
     without any third-party or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 2 devices.
@@ -354,12 +355,13 @@ class AllGather(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
         """Initialize AllGather."""
-        validator.check_value_type('group', _get_group(group), (str,), self.name)
-        self.rank = get_rank(_get_group(group))
-        self.rank_size = get_group_size(_get_group(group))
+        self.group = _get_group(group)
+        validator.check_value_type('group', self.group, (str,), self.name)
+        self.rank = get_rank(self.group)
+        self.rank_size = get_group_size(self.group)
         validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name)
         self.add_prim_attr('rank_size', self.rank_size)
-        self.add_prim_attr('group', _get_group(group))
+        self.add_prim_attr('group', self.group)
         self.add_prim_attr('fusion', 0)
         self.add_prim_attr('mean_flag', False)
         self.add_prim_attr('no_eliminate', True)
@@ -375,25 +377,6 @@ class AllGather(PrimitiveWithInfer):
         return x_dtype
 
 
-class AShardIdentity(PrimitiveWithInfer):
-    """
-    Auto parallel virtual operator. Identity operator only for shard function.
-    Do nothing in terms of infer_shape, infer_dtype, and the tensor.
-
-    It is only for internal use of parallel modules and cannot be called by users.
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        pass
-
-    def infer_shape(self, x_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype):
-        return x_dtype
-
-
 class _MiniStepAllGather(PrimitiveWithInfer):
     """
     Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for
@@ -555,7 +538,7 @@ class ReduceScatter(Primitive):
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
     without any third-party or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 2 devices.
@@ -597,11 +580,12 @@ class ReduceScatter(Primitive):
     def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
         """Initialize ReduceScatter."""
         validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
-        validator.check_value_type('group', _get_group(group), (str,), self.name)
+        self.group = _get_group(group)
+        validator.check_value_type('group', self.group, (str,), self.name)
         self.op = op
-        self.rank_size = get_group_size(_get_group(group))
+        self.rank_size = get_group_size(self.group)
         self.add_prim_attr('rank_size', self.rank_size)
-        self.add_prim_attr('group', _get_group(group))
+        self.add_prim_attr('group', self.group)
         self.add_prim_attr('fusion', 0)
         self.add_prim_attr('no_eliminate', True)
 
@@ -692,7 +676,7 @@ class Broadcast(PrimitiveWithInfer):
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
     without any third-party or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 2 devices.
@@ -922,7 +906,7 @@ class AlltoAll(PrimitiveWithInfer):
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
     without any third-party or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 8 devices.
@@ -1041,7 +1025,7 @@ class NeighborExchangeV2(Primitive):
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
     without any third-party or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 2 devices.
@@ -1158,7 +1142,7 @@ class CollectiveScatter(Primitive):
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
     without any third-party or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 2 devices.
@@ -1243,7 +1227,7 @@ class CollectiveGather(Primitive):
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
     without any third-party or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 4 devices.
@@ -1308,8 +1292,6 @@ class Barrier(PrimitiveWithInfer):
     Raises:
         TypeError: If `group` is not a str.
        RuntimeError: If backend is invalid, or distributed initialization fails.
-        ValueError: If the local rank id of the calling process in the group
-            is larger than the group's rank size.
 
     Supported Platforms:
         ``Ascend``
@@ -1321,7 +1303,7 @@ class Barrier(PrimitiveWithInfer):
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
     without any third-party or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 2 devices.
@@ -1395,7 +1377,7 @@ class Send(PrimitiveWithInfer):
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
     without any third-party or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 2 devices.
@@ -1431,7 +1413,7 @@ class Send(PrimitiveWithInfer):
     def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP):
         self.rank = dest_rank
         self.sr_tag = sr_tag
-        self.group = group
+        self.group = _get_group(group)
         self.add_prim_attr("no_eliminate", True)
 
     def infer_shape(self, x_shape):
@@ -1479,7 +1461,7 @@ class Receive(PrimitiveWithInfer):
     For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
     without any third-party or configuration file dependencies.
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 2 devices.
@@ -1517,7 +1499,7 @@ class Receive(PrimitiveWithInfer):
         self.tag = sr_tag
         self.shape = shape
         self.dtype = dtype
-        self.group = group
+        self.group = _get_group(group)
        self.add_prim_attr("no_eliminate", True)
        valid_type = [mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16,
                      mstype.int8, mstype.int16, mstype.int32, mstype.int64,
@@ -1695,9 +1677,32 @@ class _VirtualAssignAdd(PrimitiveWithInfer):
 
     def infer_dtype(self, x_dtype, y_dtype):
         return x_dtype
+
+
 virtual_assign_add = _VirtualAssignAdd()
 
 
+class _VirtualAssignKvCache(PrimitiveWithInfer):
+    """
+    Auto parallel virtual operator. Do nothing in forward, do Assign kv cache in backward. It is only for
+    internal use of parallel modules and cannot be called by users.
+
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize _VirtualAssignAdd."""
+        self.add_prim_attr('order_enforce_skip', True)
+        self.add_prim_attr('side_effect_backprop_mem', True)
+
+    def infer_shape(self, x_shape, y_shape, kv_equal_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype, y_dtype, kv_equal_dtype):
+        return x_dtype
+virtual_assign_kv_cache = _VirtualAssignKvCache()
+
+
 class _VirtualAccuGrad(PrimitiveWithInfer):
     """
     Auto parallel virtual operator. Do nothing in forward, return y in backward. It is only for
@@ -1834,7 +1839,7 @@ class BatchISendIRecv(PrimitiveWithInfer):
     without any third-party or configuration file dependencies.
 
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 2 devices.
@@ -1924,6 +1929,7 @@ class AlltoAllV(PrimitiveWithInfer):
        recv_numel_list(Union[tuple[int], list[int]]): split numel to gather from different remote rank.
        group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``, which
            means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
+       TODO:
 
    Inputs:
        - **input_x** (Tensor) - flatten tensor to scatter. The shape of tensor is :math:`(x_1)`.
@@ -1946,7 +1952,7 @@ class AlltoAllV(PrimitiveWithInfer):
     without any third-party or configuration file dependencies.
 
     Please see the `msrun start up
-    <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+    <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
     for more details.
 
     This example should be run with 2 devices.
@@ -1986,11 +1992,15 @@ class AlltoAllV(PrimitiveWithInfer):
     """
 
     @prim_attr_register
-    def __init__(self, send_numel_list, recv_numel_list, group=None):
+    def __init__(self, send_numel_list, recv_numel_list, group=None, split_sizes_empty=False):
         validator.check_value_type("send_numel_list", send_numel_list, [tuple, list], self.name)
         validator.check_value_type("recv_numel_list", recv_numel_list, [tuple, list], self.name)
-        if group is None:
-            group = GlobalComm.WORLD_COMM_GROUP
-        self.add_prim_attr('group', group)
+        self.group = GlobalComm.WORLD_COMM_GROUP if group is None else _get_group(group)
+        self.send_numel_list = send_numel_list
+        self.recv_numel_list = recv_numel_list
+        self.split_sizes_empty = split_sizes_empty
+        self.rank_size = get_group_size(self.group)
+
+        self.add_prim_attr('group', self.group)
         self.add_prim_attr('send_numel_list', send_numel_list)
         self.add_prim_attr('recv_numel_list', recv_numel_list)