mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (285)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/mindir_util.py +2 -2
  26. mindspore/common/parameter.py +46 -13
  27. mindspore/common/recompute.py +39 -9
  28. mindspore/common/sparse_tensor.py +7 -3
  29. mindspore/common/tensor.py +209 -29
  30. mindspore/communication/__init__.py +1 -1
  31. mindspore/communication/_comm_helper.py +38 -3
  32. mindspore/communication/comm_func.py +310 -55
  33. mindspore/communication/management.py +14 -14
  34. mindspore/context.py +123 -22
  35. mindspore/dataset/__init__.py +1 -1
  36. mindspore/dataset/audio/__init__.py +1 -1
  37. mindspore/dataset/core/config.py +7 -0
  38. mindspore/dataset/core/validator_helpers.py +7 -0
  39. mindspore/dataset/engine/cache_client.py +1 -1
  40. mindspore/dataset/engine/datasets.py +72 -44
  41. mindspore/dataset/engine/datasets_audio.py +7 -7
  42. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  43. mindspore/dataset/engine/datasets_text.py +20 -20
  44. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  45. mindspore/dataset/engine/datasets_vision.py +33 -33
  46. mindspore/dataset/engine/iterators.py +29 -0
  47. mindspore/dataset/engine/obs/util.py +7 -0
  48. mindspore/dataset/engine/queue.py +114 -60
  49. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  50. mindspore/dataset/engine/validators.py +34 -14
  51. mindspore/dataset/text/__init__.py +1 -4
  52. mindspore/dataset/transforms/__init__.py +0 -3
  53. mindspore/dataset/utils/line_reader.py +2 -0
  54. mindspore/dataset/vision/__init__.py +1 -4
  55. mindspore/dataset/vision/utils.py +1 -1
  56. mindspore/dataset/vision/validators.py +2 -1
  57. mindspore/dnnl.dll +0 -0
  58. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  59. mindspore/experimental/es/embedding_service.py +883 -0
  60. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  61. mindspore/experimental/llm_boost/__init__.py +21 -0
  62. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  63. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  64. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  65. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  66. mindspore/experimental/llm_boost/register.py +129 -0
  67. mindspore/experimental/llm_boost/utils.py +31 -0
  68. mindspore/experimental/optim/adamw.py +85 -0
  69. mindspore/experimental/optim/optimizer.py +3 -0
  70. mindspore/hal/__init__.py +3 -3
  71. mindspore/hal/contiguous_tensors_handle.py +175 -0
  72. mindspore/hal/stream.py +18 -0
  73. mindspore/include/api/model_group.h +13 -1
  74. mindspore/include/api/types.h +10 -10
  75. mindspore/include/dataset/config.h +2 -2
  76. mindspore/include/dataset/constants.h +2 -2
  77. mindspore/include/dataset/execute.h +2 -2
  78. mindspore/include/dataset/vision.h +4 -0
  79. mindspore/jpeg62.dll +0 -0
  80. mindspore/log.py +1 -1
  81. mindspore/mindrecord/filewriter.py +68 -51
  82. mindspore/mindspore_backend.dll +0 -0
  83. mindspore/mindspore_common.dll +0 -0
  84. mindspore/mindspore_core.dll +0 -0
  85. mindspore/mindspore_glog.dll +0 -0
  86. mindspore/mindspore_np_dtype.dll +0 -0
  87. mindspore/mindspore_ops.dll +0 -0
  88. mindspore/mint/__init__.py +495 -46
  89. mindspore/mint/distributed/__init__.py +31 -0
  90. mindspore/mint/distributed/distributed.py +254 -0
  91. mindspore/mint/nn/__init__.py +266 -21
  92. mindspore/mint/nn/functional.py +125 -19
  93. mindspore/mint/nn/layer/__init__.py +39 -0
  94. mindspore/mint/nn/layer/activation.py +133 -0
  95. mindspore/mint/nn/layer/normalization.py +477 -0
  96. mindspore/mint/nn/layer/pooling.py +110 -0
  97. mindspore/mint/optim/adamw.py +28 -7
  98. mindspore/mint/special/__init__.py +63 -0
  99. mindspore/multiprocessing/__init__.py +2 -1
  100. mindspore/nn/__init__.py +0 -1
  101. mindspore/nn/cell.py +275 -93
  102. mindspore/nn/layer/activation.py +211 -44
  103. mindspore/nn/layer/basic.py +113 -3
  104. mindspore/nn/layer/embedding.py +120 -2
  105. mindspore/nn/layer/normalization.py +101 -5
  106. mindspore/nn/layer/padding.py +34 -48
  107. mindspore/nn/layer/pooling.py +161 -7
  108. mindspore/nn/layer/transformer.py +3 -3
  109. mindspore/nn/loss/__init__.py +2 -2
  110. mindspore/nn/loss/loss.py +84 -6
  111. mindspore/nn/optim/__init__.py +2 -1
  112. mindspore/nn/optim/adadelta.py +1 -1
  113. mindspore/nn/optim/adam.py +1 -1
  114. mindspore/nn/optim/lamb.py +1 -1
  115. mindspore/nn/optim/tft_wrapper.py +127 -0
  116. mindspore/nn/wrap/cell_wrapper.py +12 -23
  117. mindspore/nn/wrap/grad_reducer.py +5 -5
  118. mindspore/nn/wrap/loss_scale.py +17 -3
  119. mindspore/numpy/__init__.py +1 -1
  120. mindspore/numpy/array_creations.py +65 -68
  121. mindspore/numpy/array_ops.py +64 -60
  122. mindspore/numpy/fft.py +610 -75
  123. mindspore/numpy/logic_ops.py +11 -10
  124. mindspore/numpy/math_ops.py +85 -84
  125. mindspore/numpy/utils_const.py +4 -4
  126. mindspore/opencv_core452.dll +0 -0
  127. mindspore/opencv_imgcodecs452.dll +0 -0
  128. mindspore/opencv_imgproc452.dll +0 -0
  129. mindspore/ops/__init__.py +6 -4
  130. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  131. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  132. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  133. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  134. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  135. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  136. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  137. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  138. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  139. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  140. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  141. mindspore/ops/composite/base.py +85 -48
  142. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  143. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  144. mindspore/ops/function/__init__.py +22 -0
  145. mindspore/ops/function/array_func.py +490 -153
  146. mindspore/ops/function/debug_func.py +113 -1
  147. mindspore/ops/function/fft_func.py +15 -2
  148. mindspore/ops/function/grad/grad_func.py +3 -2
  149. mindspore/ops/function/math_func.py +558 -207
  150. mindspore/ops/function/nn_func.py +817 -383
  151. mindspore/ops/function/other_func.py +3 -2
  152. mindspore/ops/function/random_func.py +184 -8
  153. mindspore/ops/function/reshard_func.py +13 -11
  154. mindspore/ops/function/sparse_unary_func.py +1 -1
  155. mindspore/ops/function/vmap_func.py +3 -2
  156. mindspore/ops/functional.py +24 -14
  157. mindspore/ops/op_info_register.py +3 -3
  158. mindspore/ops/operations/__init__.py +6 -1
  159. mindspore/ops/operations/_grad_ops.py +2 -76
  160. mindspore/ops/operations/_infer_ops.py +1 -1
  161. mindspore/ops/operations/_inner_ops.py +71 -94
  162. mindspore/ops/operations/array_ops.py +12 -146
  163. mindspore/ops/operations/comm_ops.py +42 -53
  164. mindspore/ops/operations/custom_ops.py +83 -19
  165. mindspore/ops/operations/debug_ops.py +42 -10
  166. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  167. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  168. mindspore/ops/operations/math_ops.py +12 -223
  169. mindspore/ops/operations/nn_ops.py +20 -114
  170. mindspore/ops/operations/other_ops.py +7 -4
  171. mindspore/ops/operations/random_ops.py +46 -1
  172. mindspore/ops/primitive.py +18 -6
  173. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  174. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  175. mindspore/ops_generate/gen_constants.py +36 -0
  176. mindspore/ops_generate/gen_ops.py +67 -52
  177. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  178. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  179. mindspore/ops_generate/op_proto.py +10 -3
  180. mindspore/ops_generate/pyboost_utils.py +14 -1
  181. mindspore/ops_generate/template.py +43 -21
  182. mindspore/parallel/__init__.py +3 -1
  183. mindspore/parallel/_auto_parallel_context.py +28 -8
  184. mindspore/parallel/_cell_wrapper.py +83 -0
  185. mindspore/parallel/_parallel_serialization.py +47 -19
  186. mindspore/parallel/_tensor.py +81 -11
  187. mindspore/parallel/_utils.py +13 -1
  188. mindspore/parallel/algo_parameter_config.py +5 -5
  189. mindspore/parallel/checkpoint_transform.py +46 -39
  190. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  191. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  192. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  193. mindspore/parallel/parameter_broadcast.py +3 -4
  194. mindspore/parallel/shard.py +162 -31
  195. mindspore/parallel/transform_safetensors.py +993 -0
  196. mindspore/profiler/__init__.py +2 -1
  197. mindspore/profiler/common/constant.py +29 -0
  198. mindspore/profiler/common/registry.py +47 -0
  199. mindspore/profiler/common/util.py +28 -0
  200. mindspore/profiler/dynamic_profiler.py +694 -0
  201. mindspore/profiler/envprofiling.py +17 -19
  202. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  203. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  204. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  205. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  206. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  207. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  208. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  209. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  210. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  211. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  212. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  213. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  214. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  215. mindspore/profiler/parser/framework_parser.py +1 -391
  216. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  217. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  218. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  219. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  220. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  221. mindspore/profiler/parser/profiler_info.py +78 -6
  222. mindspore/profiler/profiler.py +153 -0
  223. mindspore/profiler/profiling.py +280 -412
  224. mindspore/rewrite/__init__.py +1 -2
  225. mindspore/rewrite/common/namespace.py +4 -4
  226. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  227. mindspore/run_check/_check_version.py +36 -103
  228. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  229. mindspore/swresample-4.dll +0 -0
  230. mindspore/swscale-6.dll +0 -0
  231. mindspore/tinyxml2.dll +0 -0
  232. mindspore/train/__init__.py +4 -3
  233. mindspore/train/_utils.py +28 -2
  234. mindspore/train/amp.py +171 -53
  235. mindspore/train/callback/__init__.py +2 -2
  236. mindspore/train/callback/_callback.py +4 -4
  237. mindspore/train/callback/_checkpoint.py +85 -22
  238. mindspore/train/callback/_cluster_monitor.py +1 -1
  239. mindspore/train/callback/_flops_collector.py +1 -0
  240. mindspore/train/callback/_loss_monitor.py +3 -3
  241. mindspore/train/callback/_on_request_exit.py +134 -31
  242. mindspore/train/callback/_summary_collector.py +5 -5
  243. mindspore/train/callback/_tft_register.py +352 -0
  244. mindspore/train/dataset_helper.py +7 -3
  245. mindspore/train/metrics/metric.py +3 -3
  246. mindspore/train/metrics/roc.py +4 -4
  247. mindspore/train/mind_ir_pb2.py +44 -39
  248. mindspore/train/model.py +134 -58
  249. mindspore/train/serialization.py +336 -112
  250. mindspore/turbojpeg.dll +0 -0
  251. mindspore/utils/__init__.py +21 -0
  252. mindspore/utils/utils.py +60 -0
  253. mindspore/version.py +1 -1
  254. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  255. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
  256. mindspore/include/c_api/ms/abstract.h +0 -67
  257. mindspore/include/c_api/ms/attribute.h +0 -197
  258. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  259. mindspore/include/c_api/ms/base/macros.h +0 -32
  260. mindspore/include/c_api/ms/base/status.h +0 -33
  261. mindspore/include/c_api/ms/base/types.h +0 -283
  262. mindspore/include/c_api/ms/context.h +0 -102
  263. mindspore/include/c_api/ms/graph.h +0 -160
  264. mindspore/include/c_api/ms/node.h +0 -606
  265. mindspore/include/c_api/ms/tensor.h +0 -161
  266. mindspore/include/c_api/ms/value.h +0 -84
  267. mindspore/mindspore_shared_lib.dll +0 -0
  268. mindspore/nn/extend/basic.py +0 -140
  269. mindspore/nn/extend/embedding.py +0 -143
  270. mindspore/nn/extend/layer/normalization.py +0 -109
  271. mindspore/nn/extend/pooling.py +0 -117
  272. mindspore/nn/layer/embedding_service.py +0 -531
  273. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  274. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  275. mindspore/ops/extend/__init__.py +0 -53
  276. mindspore/ops/extend/array_func.py +0 -218
  277. mindspore/ops/extend/math_func.py +0 -76
  278. mindspore/ops/extend/nn_func.py +0 -308
  279. mindspore/ops/silent_check.py +0 -162
  280. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  281. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  282. mindspore/train/callback/_mindio_ttp.py +0 -443
  283. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  284. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  285. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
@@ -35,8 +35,8 @@ from ..auto_generate import (AbsGrad, ACosGrad, LogitGrad, AcoshGrad, AsinGrad,
35
35
  SigmoidGrad, HSwishGrad, NLLLossGrad, AtanGrad, GridSampler3DGrad, GridSampler2DGrad,
36
36
  ResizeBicubicGrad, HSigmoidGrad, CholeskyGrad, ResizeNearestNeighborGrad, LayerNormGrad,
37
37
  HShrinkGrad, LayerNormGradGrad, SiLUGrad, MaximumGrad, MaximumGradGrad, RmsNormGrad,
38
- FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad,
39
- BinaryCrossEntropyGrad)
38
+ FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad, MaskedSelectGrad,
39
+ BinaryCrossEntropyGrad, SoftShrinkGrad, SeluGrad)
40
40
 
41
41
 
42
42
  class SparseFillEmptyRowsGrad(Primitive):
@@ -1658,35 +1658,6 @@ class SoftMarginLossGrad(Primitive):
1658
1658
  self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
1659
1659
 
1660
1660
 
1661
- class StridedSliceV2Grad(Primitive):
1662
- """
1663
- Performs grad of StridedSliceV2 operation.
1664
-
1665
- Inputs:
1666
- - **shapex** (Tensor) - StridedSliceV2 shape of input
1667
- - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
1668
- constant value is allowed.
1669
- - **end** (tuple[int]) - A tuple or which represents the maximum location where to end.
1670
- Only constant value is allowed.
1671
- - **strides** (tuple[int]) - A tuple which represents the stride is continuously added
1672
- before reaching the maximum location. Only constant value is allowed.
1673
- - **dy** (Tensor) - The output of StridedSliceV2
1674
-
1675
- Outputs:
1676
- Tensor, the shape same as the input of StridedSliceV2
1677
- """
1678
-
1679
- @prim_attr_register
1680
- def __init__(self,
1681
- begin_mask=0,
1682
- end_mask=0,
1683
- ellipsis_mask=0,
1684
- new_axis_mask=0,
1685
- shrink_axis_mask=0):
1686
- """Initialize StridedSliceV2Grad"""
1687
- self.init_prim_io_names(inputs=['shapex', 'begin', 'end', 'strides', 'dy'], outputs=['output'])
1688
-
1689
-
1690
1661
  class StridedSliceGrad(Primitive):
1691
1662
  """
1692
1663
  Performs grad of StridedSlice operation.
@@ -1991,51 +1962,6 @@ class MvlgammaGrad(Primitive):
1991
1962
  self.p = validator.check_value_type('p', p, [int], self.name)
1992
1963
 
1993
1964
 
1994
- class MaskedSelectGrad(PrimitiveWithInfer):
1995
- """Computes gradient for MaskedSelect."""
1996
-
1997
- @prim_attr_register
1998
- def __init__(self):
1999
- pass
2000
-
2001
- def infer_shape(self, x, mask, grad):
2002
- return x
2003
-
2004
- def infer_dtype(self, x, mask, grad):
2005
- return x
2006
-
2007
-
2008
- class SoftShrinkGrad(Primitive):
2009
- r"""
2010
- Gradients for SoftShrink operation.
2011
-
2012
- Args:
2013
- lambd – The \lambdaλ (must be no less than zero) value for the Softshrink formulation. Default: 0.5.
2014
-
2015
- Inputs:
2016
- - **input_grad** (Tensor) - The input gradient.
2017
- - **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32.
2018
- Any number of additional dimensions.
2019
-
2020
- Outputs:
2021
- output - Tensor, has the same shape and data type as input_x.
2022
-
2023
- Raises:
2024
- TypeError: If lambd is not a float.
2025
- TypeError: If dtype of input_x is neither float16 nor float32.
2026
- ValueError: If lambd is less than to 0.
2027
-
2028
- Supported Platforms:
2029
- ``Ascend``
2030
- """
2031
-
2032
- @prim_attr_register
2033
- def __init__(self, lambd=0.5):
2034
- self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output'])
2035
- validator.check_value_type("lambd", lambd, [float], self.name)
2036
- validator.check_number("lambd", lambd, 0, validator.GE, self.name)
2037
-
2038
-
2039
1965
  class CdistGrad(Primitive):
2040
1966
  """Computes gradient for Cdist."""
2041
1967
 
@@ -16,4 +16,4 @@
16
16
  """Operator of infer net"""
17
17
  # pylint: disable=unused-import
18
18
  from ..auto_generate import (QuantV2, DynamicQuantExt, QuantBatchMatmul, WeightQuantBatchMatmul, KVCacheScatterUpdate,
19
- FusedInferAttentionScore, GroupedMatmul, MoeFinalizeRouting)
19
+ FusedInferAttentionScore, GroupedMatmul, MoeFinalizeRouting, QuantLinearSparse)
@@ -17,6 +17,7 @@
17
17
  from types import FunctionType, MethodType
18
18
  from collections.abc import Iterable
19
19
  import os
20
+ import weakref
20
21
  import numpy as np
21
22
 
22
23
  from mindspore.common import Tensor
@@ -29,7 +30,7 @@ from mindspore.ops.operations.math_ops import _infer_shape_reduce
29
30
  from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
30
31
  _run_op, _check_contains_variable
31
32
  from mindspore._c_expression import Tensor as Tensor_
32
- from mindspore._c_expression import typing
33
+ from mindspore._c_expression import typing, HookType
33
34
  from mindspore import _checkparam as validator
34
35
  from mindspore.common import dtype as mstype
35
36
  from mindspore.common.parameter import Parameter
@@ -1535,7 +1536,7 @@ class CellBackwardHook(PrimitiveWithInfer):
1535
1536
  ... print(grad)
1536
1537
  ...
1537
1538
  >>> hook = inner.CellBackwardHook()
1538
- >>> hook_fn_key = hook.register_backward_hook(hook_fn)
1539
+ >>> hook_fn_key = hook.register_backward_hook()
1539
1540
  >>> def hook_test(x, y):
1540
1541
  ... z = x * y
1541
1542
  ... z = hook(z)
@@ -1556,16 +1557,19 @@ class CellBackwardHook(PrimitiveWithInfer):
1556
1557
  (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
1557
1558
  """
1558
1559
 
1559
- def __init__(self, cell_id=""):
1560
+ def __init__(self, cell_id="", cell=None, hook_dict=None):
1560
1561
  """Initialize CellBackwardHook"""
1561
1562
  super(CellBackwardHook, self).__init__(self.__class__.__name__)
1562
1563
  self.cell_id = cell_id
1564
+ self.cell = cell
1565
+ self.hook_dict = weakref.ref(hook_dict)
1563
1566
  self.add_prim_attr("cell_id", cell_id)
1564
- self.init_attrs["cell_id"] = cell_id
1567
+ self.grad_output = None
1565
1568
 
1566
- def __call__(self, args):
1567
- if not isinstance(args, tuple):
1568
- args = (args,)
1569
+ def __call__(self, *args):
1570
+ # If args is empty, just return.
1571
+ if not args:
1572
+ return args
1569
1573
  return _run_op(self, self.name, args)
1570
1574
 
1571
1575
  def infer_shape(self, *inputs_shape):
@@ -1578,51 +1582,76 @@ class CellBackwardHook(PrimitiveWithInfer):
1578
1582
  return inputs_type[0]
1579
1583
  return inputs_type
1580
1584
 
1581
- def register_backward_hook(self, hook_fn):
1582
- r"""
1583
- This function is used to register backward hook function. Note that this function is only supported in pynative
1584
- mode.
1585
-
1586
- Note:
1587
- The 'hook_fn' must be defined as the following code.
1588
- `cell_id` is the information of registered cell. `grad_input` is the gradient passed to the cell.
1589
- `grad_output` is the gradient computed and passed to the next cell or primitive, which may be modified by
1590
- returning a new output gradient.
1591
- The 'hook_fn' should have the following signature:
1592
- hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
1593
- The 'hook_fn' is executed in the python environment.
1585
+ def register_backward_hook(self):
1586
+ """
1587
+ Register the backward hook function.
1594
1588
 
1595
1589
  Args:
1596
- hook_fn (Function): Python function. Backward hook function.
1590
+ None
1597
1591
 
1598
1592
  Returns:
1599
- - **key** (int) - The key of 'hook_fn'.
1593
+ None
1600
1594
 
1601
- Raises:
1602
- TypeError: If the `hook_fn` is not a function of python.
1595
+ Supported Platforms:
1596
+ ``Ascend`` ``GPU`` ``CPU``
1603
1597
  """
1604
- if not isinstance(hook_fn, (FunctionType, MethodType)):
1605
- raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' must be python "
1606
- f"function, but got {type(hook_fn)}.")
1607
- key = self.add_backward_hook_fn(hook_fn)
1608
- return key
1609
1598
 
1610
- def remove_backward_hook(self, key):
1611
- r"""
1612
- This function is used to remove backward hook function. Note that this operation is only supported in pynative
1613
- mode.
1614
-
1615
- Note:
1616
- The 'key' is the object returned by 'register_backward_hook' function of the same CellBackwardHook
1617
- operator.
1599
+ def hook_backward_grad(grad):
1600
+ if self.grad_output is None:
1601
+ self.grad_output = grad
1602
+ # Indicates the first time of call backward hook, and need to wait for the second time call
1603
+ return self.cell_id
1604
+ backward_hook_grad_input = grad
1605
+ if self.hook_dict():
1606
+ backward_hooks = self.hook_dict().values()
1607
+ for hook in backward_hooks:
1608
+ res = hook(self.cell, backward_hook_grad_input, self.grad_output)
1609
+ if res is None:
1610
+ continue
1611
+ if not isinstance(res, tuple):
1612
+ res = (res,)
1613
+ if len(res) != len(grad):
1614
+ raise TypeError(
1615
+ "The backward hook return value size is {} not equal to expect grad input size {}".format(
1616
+ len(res), len(grad)))
1617
+ backward_hook_grad_input = res
1618
+ self.grad_output = None
1619
+ return backward_hook_grad_input
1620
+
1621
+ self.set_hook_fn(hook_backward_grad, HookType.BackwardHook)
1622
+
1623
+ def register_backward_pre_hook(self):
1624
+ """
1625
+ Register the backward pre hook function.
1618
1626
 
1619
1627
  Args:
1620
- key (int): The key corresponding to the 'hook_fn'.
1628
+ None
1621
1629
 
1622
1630
  Returns:
1623
- None.
1631
+ None
1632
+
1633
+ Supported Platforms:
1634
+ ``Ascend`` ``GPU`` ``CPU``
1624
1635
  """
1625
- self.remove_backward_hook_fn(key)
1636
+
1637
+ def hook_backward_pre_grad(grad):
1638
+ backward_pre_hook_grad = grad
1639
+ if self.hook_dict():
1640
+ backward_pre_hooks = self.hook_dict().values()
1641
+ for hook in backward_pre_hooks:
1642
+ res = hook(self.cell, backward_pre_hook_grad)
1643
+ if res is None:
1644
+ continue
1645
+ if not isinstance(res, tuple):
1646
+ res = (res,)
1647
+ if len(res) != len(grad):
1648
+ raise TypeError(
1649
+ "The backward pre hook return value size is {} not equal to expect output size {}".format(
1650
+ len(res), len(grad)))
1651
+ backward_pre_hook_grad = res
1652
+ return backward_pre_hook_grad
1653
+
1654
+ self.set_hook_fn(hook_backward_pre_grad, HookType.BackwardPreHook)
1626
1655
 
1627
1656
 
1628
1657
  class Format(PrimitiveWithInfer):
@@ -2478,60 +2507,6 @@ class FFN(Primitive):
2478
2507
  validator.check_value_type("inner_precise", inner_precise, [int], cls_name)
2479
2508
 
2480
2509
 
2481
- class _MirrorSilentCheck(PrimitiveWithInfer):
2482
- """
2483
- The operator _MirrorSilentCheck implements accuracy-sensitive detection on the tensor input in backpropagator.
2484
- Call _MirrorSilentCheck in method __call__ of derived class to implement accuracy-sensitive detection.
2485
-
2486
- Inputs:
2487
- - **input** (Tensor) : The tensor used for detection.
2488
- Its data type must be mindspore.float16, mindspore.float32 or mindspore.bfloat16.
2489
- - **pre_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
2490
- Please only generated by method generate_params() of ASDBase.
2491
- - **min_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
2492
- Please only generated by method generate_params() of ASDBase.
2493
- - **max_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
2494
- Please only generated by method generate_params() of ASDBase.
2495
- - **cnt** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
2496
- Please only generated by method generate_params() of ASDBase.
2497
- After each invocation of _MirrorSilentCheck, increment the value of cnt by one.
2498
-
2499
- Outputs:
2500
- - **output** (Tensor) - Same shape, type and value as `input`.
2501
- """
2502
- @prim_attr_register
2503
- def __init__(self, min_steps=8):
2504
- upper_thresh, sigma_thresh = self.get_thresh()
2505
- self.min_steps = min_steps
2506
- self.thresh_l1 = upper_thresh[0]
2507
- self.coeff_l1 = sigma_thresh[0]
2508
- self.thresh_l2 = upper_thresh[1]
2509
- self.coeff_l2 = sigma_thresh[1]
2510
- self.add_prim_attr('side_effect_mem', True)
2511
-
2512
- def parse_thresh(self, env_var_name, default_value, min_value):
2513
- env_var = os.environ.get(env_var_name, default=default_value)
2514
- thresh = [value.strip() for value in env_var.split(",")]
2515
- if len(thresh) != 2 or not all(value.isdigit() for value in thresh):
2516
- thresh = default_value.split(",")
2517
- thresh = [float(max(int(value), min_value)) for value in thresh]
2518
- if thresh[0] <= thresh[1]:
2519
- thresh = [float(value) for value in default_value.split(",")]
2520
-
2521
- return thresh
2522
-
2523
- def get_thresh(self):
2524
- upper_thresh = self.parse_thresh("NPU_ASD_UPPER_THRESH", "1000000,10000", 3)
2525
- sigma_thresh = self.parse_thresh("NPU_ASD_SIGMA_THRESH", "100000,5000", 3)
2526
- return upper_thresh, sigma_thresh
2527
-
2528
- def infer_shape(self, x_shape, pre_shape, min_shape, max_shape, n_step, loss_scale_shape):
2529
- return x_shape
2530
-
2531
- def infer_dtype(self, x_dtype, pre_dtype, min_dtype, max_dtype, n_dtype, loss_scale_dtype):
2532
- return x_dtype
2533
-
2534
-
2535
2510
  class _VirtualConverterEnd(PrimitiveWithInfer):
2536
2511
  """
2537
2512
  Auto parallel virtual operator.
@@ -2560,6 +2535,8 @@ class _VirtualConverterBegin(PrimitiveWithInfer):
2560
2535
  self.output_nums = output_nums
2561
2536
 
2562
2537
  def infer_shape(self, arg):
2538
+ if self.output_nums == 0:
2539
+ return ValueError("output_nums can\'t be zero.")
2563
2540
  new_arg = (arg[0] / self.output_nums,) + tuple(arg[1:])
2564
2541
  return (new_arg,) * self.output_nums
2565
2542
 
@@ -39,11 +39,12 @@ from ..auto_generate import (ExpandDims, Reshape, TensorShape, Transpose, Gather
39
39
  OnesLike, ZerosLike, Argmax, ArgMaxExt,
40
40
  ReverseV2, Diag, Eye, ScatterNd, ResizeNearestNeighborV2,
41
41
  GatherNd, GatherD, Range, MaskedFill, RightShift, NonZero,
42
- ResizeNearestNeighbor, Identity, Split, CumSum, CumProd,
42
+ ResizeNearestNeighbor, Identity, Split, CumSum, CumProd, MaskedSelect,
43
43
  Cummax, Cummin, Argmin, Concat, UnsortedSegmentSum, ScalarToTensor,
44
44
  Triu, BroadcastTo, StridedSlice, Select, TopkExt, SearchSorted)
45
45
  from .manually_defined import Rank, Shape, Tile, Cast, Ones, Zeros
46
46
  from ..auto_generate import ArgMaxWithValue, ArgMinWithValue
47
+ from ..auto_generate import TensorScatterElements as TensorScatterElementsExt
47
48
 
48
49
  class _ScatterOp(PrimitiveWithInfer):
49
50
  """
@@ -769,41 +770,13 @@ class Padding(Primitive):
769
770
 
770
771
  class UniqueWithPad(Primitive):
771
772
  """
772
- Returns unique elements and relative indexes in 1-D tensor, filled with padding num.
773
-
774
- The basic function is the same as the Unique operator, but the UniqueWithPad operator adds a Pad function.
775
- The returned tuple(`y`, `idx`) after the input Tensor `x` is processed by the unique operator,
776
- in which the shapes of `y` and `idx` are mostly not equal. Therefore, in order to solve the above situation,
777
- the UniqueWithPad operator will fill the `y` Tensor with the `pad_num` specified by the user
778
- to make it have the same shape as the Tensor `idx`.
779
-
780
- Refer to :func:`mindspore.ops.unique_with_pad` for more details.
781
-
782
- Inputs:
783
- - **x** (Tensor) - The tensor need to be unique. Must be 1-D vector with types: int32, int64.
784
- - **pad_num** (int) - Pad num. The data type is an int.
785
-
786
- Outputs:
787
- tuple(Tensor), tuple of 2 tensors, `y` and `idx`.
788
-
789
- - y (Tensor) - The unique elements filled with pad_num, the shape and data type same as `x`.
790
- - idx (Tensor) - The index of each value of `x` in the unique output `y`, the shape and data type same as `x`.
773
+ 'ops.UniqueWithPad' is deprecated from version 2.4 and will be removed in a future version.
791
774
 
792
775
  Supported Platforms:
793
- ``Ascend`` ``GPU`` ``CPU``
794
-
795
- Examples:
796
- >>> import mindspore
797
- >>> import numpy as np
798
- >>> from mindspore import Tensor, ops
799
- >>> x = Tensor(np.array([1, 1, 2, 2, 3, 3, 4, 5]), mindspore.int32)
800
- >>> pad_num = 8
801
- >>> output = ops.UniqueWithPad()(x, pad_num)
802
- >>> print(output)
803
- (Tensor(shape=[8], dtype=Int32, value= [1, 2, 3, 4, 5, 8, 8, 8]),
804
- Tensor(shape=[8], dtype=Int32, value= [0, 0, 1, 1, 2, 2, 3, 4]))
776
+ Deprecated
805
777
  """
806
778
 
779
+ @deprecated("2.4", "ops.Unique and ops.PadV3", False)
807
780
  @prim_attr_register
808
781
  def __init__(self):
809
782
  """init UniqueWithPad"""
@@ -819,7 +792,7 @@ class Size(Primitive):
819
792
 
820
793
  Inputs:
821
794
  - **input_x** (Tensor) - Input parameters, the shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is
822
- `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.
795
+ `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
823
796
 
824
797
  Outputs:
825
798
  int. A scalar representing the elements' size of `input_x`, tensor is the number of elements
@@ -2112,60 +2085,6 @@ class Rint(Primitive):
2112
2085
  self.init_prim_io_names(inputs=['x'], outputs=['output'])
2113
2086
 
2114
2087
 
2115
- class StridedSliceV2(Primitive):
2116
- r"""
2117
- StridedSliceV2 will be deprecated by StridedSlice in the future.
2118
- Extracts a strided slice of a tensor.
2119
- Refer to class StridedSlice for more details.
2120
-
2121
- Args:
2122
- begin_mask (int): Starting index of the slice. Default: ``0`` .
2123
- end_mask (int): Ending index of the slice. Default: ``0`` .
2124
- ellipsis_mask (int): An int mask. Default: ``0`` .
2125
- new_axis_mask (int): An int mask. Default: ``0`` .
2126
- shrink_axis_mask (int): An int mask. Default: ``0`` .
2127
-
2128
- Inputs:
2129
- - **input_x** (Tensor) - The input Tensor.
2130
- - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
2131
- constant value is allowed.
2132
- - **end** (tuple[int]) - A tuple or which represents the maximum location where to end.
2133
- Only constant value is allowed.
2134
- - **strides** (tuple[int]) - A tuple which represents the stride is continuously added
2135
- before reaching the maximum location. Only constant value is allowed.
2136
-
2137
- Outputs:
2138
- Tensor, The output is explained by following example.
2139
-
2140
- Raises:
2141
- TypeError: If `begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask` or `shrink_axis_mask` is not an int.
2142
- TypeError: If `begin`, `end` or `strides` is not a tuple.
2143
- ValueError: If `begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask` or `shrink_axis_mask` is less than 0.
2144
-
2145
- Supported Platforms:
2146
- ``Ascend`` ``CPU``
2147
-
2148
- Examples:
2149
- >>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
2150
- ... [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
2151
- >>> strided_slice_v2 = ops.StridedSliceV2()
2152
- >>> output = strided_slice_v2(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
2153
- >>> print(output)
2154
- [[[3.]]
2155
- [[5.]]]
2156
- """
2157
-
2158
- @prim_attr_register
2159
- def __init__(self,
2160
- begin_mask=0,
2161
- end_mask=0,
2162
- ellipsis_mask=0,
2163
- new_axis_mask=0,
2164
- shrink_axis_mask=0):
2165
- """Initialize StridedSliceV2"""
2166
- self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
2167
-
2168
-
2169
2088
  class DiagPart(PrimitiveWithCheck):
2170
2089
  r"""
2171
2090
 
@@ -4356,53 +4275,6 @@ class MaskedScatter(Primitive):
4356
4275
  self.init_prim_io_names(inputs=['x', 'mask', 'updates'], outputs=['y'])
4357
4276
 
4358
4277
 
4359
- class MaskedSelect(PrimitiveWithCheck):
4360
- """
4361
- Returns a new 1-D Tensor which indexes the `x` tensor according to the boolean `mask`.
4362
- The shapes of the `mask` tensor and the `x` tensor don't need to match, but they must be broadcastable.
4363
-
4364
- Inputs:
4365
- - **x** (Tensor) - Input Tensor of any dimension.
4366
- - **mask** (Tensor[bool]) - Boolean mask Tensor, has the same shape as `x`.
4367
-
4368
- Outputs:
4369
- A 1-D Tensor, with the same type as x.
4370
-
4371
- Raises:
4372
- TypeError: If `x` or `mask` is not a Tensor.
4373
- TypeError: If dtype of `mask` is not bool.
4374
-
4375
- Supported Platforms:
4376
- ``Ascend`` ``GPU`` ``CPU``
4377
-
4378
- Examples:
4379
- >>> import mindspore
4380
- >>> import numpy as np
4381
- >>> from mindspore import Tensor, ops
4382
- >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int32)
4383
- >>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
4384
- >>> output = ops.MaskedSelect()(x, mask)
4385
- >>> print(output)
4386
- [1 3]
4387
- >>> x = Tensor(2.1, mindspore.float32)
4388
- >>> mask = Tensor(True, mindspore.bool_)
4389
- >>> output = ops.MaskedSelect()(x, mask)
4390
- >>> print(output)
4391
- [2.1]
4392
- """
4393
-
4394
- @prim_attr_register
4395
- def __init__(self):
4396
- self.init_prim_io_names(inputs=['x', 'mask'], outputs=['output'])
4397
-
4398
- def check_shape(self, x_shape, mask_shape):
4399
- get_broadcast_shape(x_shape, mask_shape, self.name, arg_name1="x", arg_name2="mask")
4400
-
4401
- def check_dtype(self, x_dtype, mask_dtype):
4402
- validator.check_tensor_dtype_valid('mask', mask_dtype, [mstype.bool_], self.name)
4403
- validator.check_tensor_dtype_valid('x', x_dtype, (mstype.bool_,) + mstype.number_type, self.name)
4404
-
4405
-
4406
4278
  class _TensorScatterOp(PrimitiveWithInfer):
4407
4279
  """
4408
4280
  Defines TensorScatter Base Operators
@@ -4962,7 +4834,7 @@ class SplitV(Primitive):
4962
4834
  self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
4963
4835
 
4964
4836
 
4965
- class TensorScatterElements(Primitive):
4837
+ class TensorScatterElements(TensorScatterElementsExt):
4966
4838
  """
4967
4839
  Write all elements in `updates` to the index specified by `indices` in `input_x` according to the reduction
4968
4840
  operation specified by `reduction`.
@@ -4977,6 +4849,9 @@ class TensorScatterElements(Primitive):
4977
4849
  .. warning::
4978
4850
  This is an experimental API that is subject to change or deletion.
4979
4851
 
4852
+ Note:
4853
+ The backward is supported only for the case `updates.shape == indices.shape`.
4854
+
4980
4855
  Args:
4981
4856
  axis (int, optional): Specify which axis to do scatter operation. Default: ``0`` .
4982
4857
  reduction (str, optional): Which reduction operation to scatter, default is ``"none"`` . Other option: "add".
@@ -4986,7 +4861,7 @@ class TensorScatterElements(Primitive):
4986
4861
  - **indices** (Tensor) - The index of `input_x` to do scatter operation whose data type must be int32 or
4987
4862
  int64. It has the same rank as `data`. And accepted range is [-s, s) where s is the size along axis.
4988
4863
  - **updates** (Tensor) - The tensor doing the scatter operation with `data`,
4989
- it has the same type as `data` and the same shape as `indices`.
4864
+ it has the same type as `data`.
4990
4865
 
4991
4866
  Outputs:
4992
4867
  Tensor, has the same shape and type as `data`.
@@ -5021,16 +4896,7 @@ class TensorScatterElements(Primitive):
5021
4896
 
5022
4897
  @prim_attr_register
5023
4898
  def __init__(self, axis=0, reduction="none"):
5024
- """Initialize TensorScatterElements"""
5025
- validator.check_value_type("axis", axis, [int], self.name)
5026
- validator.check_value_type("reduction", reduction, [str], self.name)
5027
- validator.check_string(reduction, ["none", "add"], "reduction", self.name)
5028
- self.init_prim_io_names(inputs=['data', 'indices', 'updates'], outputs=['y'])
5029
- target = context.get_context("device_target")
5030
- if reduction != 'none' and target.lower() == "ascend":
5031
- raise ValueError(f"For '{self.name}', "
5032
- f"Currently Ascend device_target only support `reduction`='none', "
5033
- f"but got {reduction}")
4899
+ super().__init__(axis, reduce=reduction)
5034
4900
 
5035
4901
 
5036
4902
  class ExtractVolumePatches(Primitive):