mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (287) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/initializer.py +51 -15
  26. mindspore/common/mindir_util.py +2 -2
  27. mindspore/common/parameter.py +62 -15
  28. mindspore/common/recompute.py +39 -9
  29. mindspore/common/sparse_tensor.py +7 -3
  30. mindspore/common/tensor.py +183 -37
  31. mindspore/communication/__init__.py +1 -1
  32. mindspore/communication/_comm_helper.py +38 -3
  33. mindspore/communication/comm_func.py +315 -60
  34. mindspore/communication/management.py +14 -14
  35. mindspore/context.py +132 -22
  36. mindspore/dataset/__init__.py +1 -1
  37. mindspore/dataset/audio/__init__.py +1 -1
  38. mindspore/dataset/core/config.py +7 -0
  39. mindspore/dataset/core/validator_helpers.py +7 -0
  40. mindspore/dataset/engine/cache_client.py +1 -1
  41. mindspore/dataset/engine/datasets.py +72 -44
  42. mindspore/dataset/engine/datasets_audio.py +7 -7
  43. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  44. mindspore/dataset/engine/datasets_text.py +20 -20
  45. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  46. mindspore/dataset/engine/datasets_vision.py +33 -33
  47. mindspore/dataset/engine/iterators.py +29 -0
  48. mindspore/dataset/engine/obs/util.py +7 -0
  49. mindspore/dataset/engine/queue.py +114 -60
  50. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  51. mindspore/dataset/engine/validators.py +34 -14
  52. mindspore/dataset/text/__init__.py +1 -4
  53. mindspore/dataset/transforms/__init__.py +0 -3
  54. mindspore/dataset/utils/line_reader.py +2 -0
  55. mindspore/dataset/vision/__init__.py +1 -4
  56. mindspore/dataset/vision/utils.py +1 -1
  57. mindspore/dataset/vision/validators.py +2 -1
  58. mindspore/dnnl.dll +0 -0
  59. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  60. mindspore/experimental/es/embedding_service.py +883 -0
  61. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  62. mindspore/experimental/llm_boost/__init__.py +21 -0
  63. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  64. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  65. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  66. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  67. mindspore/experimental/llm_boost/register.py +129 -0
  68. mindspore/experimental/llm_boost/utils.py +31 -0
  69. mindspore/experimental/optim/adamw.py +85 -0
  70. mindspore/experimental/optim/optimizer.py +3 -0
  71. mindspore/hal/__init__.py +3 -3
  72. mindspore/hal/contiguous_tensors_handle.py +175 -0
  73. mindspore/hal/stream.py +18 -0
  74. mindspore/include/api/model_group.h +13 -1
  75. mindspore/include/api/types.h +10 -10
  76. mindspore/include/dataset/config.h +2 -2
  77. mindspore/include/dataset/constants.h +2 -2
  78. mindspore/include/dataset/execute.h +2 -2
  79. mindspore/include/dataset/vision.h +4 -0
  80. mindspore/jpeg62.dll +0 -0
  81. mindspore/log.py +1 -1
  82. mindspore/mindrecord/filewriter.py +68 -51
  83. mindspore/mindspore_backend.dll +0 -0
  84. mindspore/mindspore_common.dll +0 -0
  85. mindspore/mindspore_core.dll +0 -0
  86. mindspore/mindspore_glog.dll +0 -0
  87. mindspore/mindspore_np_dtype.dll +0 -0
  88. mindspore/mindspore_ops.dll +0 -0
  89. mindspore/mint/__init__.py +983 -46
  90. mindspore/mint/distributed/__init__.py +31 -0
  91. mindspore/mint/distributed/distributed.py +254 -0
  92. mindspore/mint/nn/__init__.py +268 -23
  93. mindspore/mint/nn/functional.py +125 -19
  94. mindspore/mint/nn/layer/__init__.py +39 -0
  95. mindspore/mint/nn/layer/activation.py +133 -0
  96. mindspore/mint/nn/layer/normalization.py +477 -0
  97. mindspore/mint/nn/layer/pooling.py +110 -0
  98. mindspore/mint/optim/adamw.py +26 -13
  99. mindspore/mint/special/__init__.py +63 -0
  100. mindspore/multiprocessing/__init__.py +2 -1
  101. mindspore/nn/__init__.py +0 -1
  102. mindspore/nn/cell.py +276 -96
  103. mindspore/nn/layer/activation.py +211 -44
  104. mindspore/nn/layer/basic.py +137 -10
  105. mindspore/nn/layer/embedding.py +137 -2
  106. mindspore/nn/layer/normalization.py +101 -5
  107. mindspore/nn/layer/padding.py +34 -48
  108. mindspore/nn/layer/pooling.py +161 -7
  109. mindspore/nn/layer/transformer.py +3 -3
  110. mindspore/nn/loss/__init__.py +2 -2
  111. mindspore/nn/loss/loss.py +84 -6
  112. mindspore/nn/optim/__init__.py +2 -1
  113. mindspore/nn/optim/adadelta.py +1 -1
  114. mindspore/nn/optim/adam.py +1 -1
  115. mindspore/nn/optim/lamb.py +1 -1
  116. mindspore/nn/optim/tft_wrapper.py +124 -0
  117. mindspore/nn/wrap/cell_wrapper.py +12 -23
  118. mindspore/nn/wrap/grad_reducer.py +5 -5
  119. mindspore/nn/wrap/loss_scale.py +17 -3
  120. mindspore/numpy/__init__.py +1 -1
  121. mindspore/numpy/array_creations.py +65 -68
  122. mindspore/numpy/array_ops.py +64 -60
  123. mindspore/numpy/fft.py +610 -75
  124. mindspore/numpy/logic_ops.py +11 -10
  125. mindspore/numpy/math_ops.py +85 -84
  126. mindspore/numpy/utils_const.py +4 -4
  127. mindspore/opencv_core452.dll +0 -0
  128. mindspore/opencv_imgcodecs452.dll +0 -0
  129. mindspore/opencv_imgproc452.dll +0 -0
  130. mindspore/ops/__init__.py +6 -4
  131. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  132. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  133. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  134. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  135. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  136. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  137. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  138. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  139. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  140. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  141. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  142. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  143. mindspore/ops/composite/base.py +85 -48
  144. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  145. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  146. mindspore/ops/function/__init__.py +22 -0
  147. mindspore/ops/function/array_func.py +492 -153
  148. mindspore/ops/function/debug_func.py +113 -1
  149. mindspore/ops/function/fft_func.py +15 -2
  150. mindspore/ops/function/grad/grad_func.py +3 -2
  151. mindspore/ops/function/math_func.py +564 -207
  152. mindspore/ops/function/nn_func.py +817 -383
  153. mindspore/ops/function/other_func.py +3 -2
  154. mindspore/ops/function/random_func.py +402 -12
  155. mindspore/ops/function/reshard_func.py +13 -11
  156. mindspore/ops/function/sparse_unary_func.py +1 -1
  157. mindspore/ops/function/vmap_func.py +3 -2
  158. mindspore/ops/functional.py +24 -14
  159. mindspore/ops/op_info_register.py +3 -3
  160. mindspore/ops/operations/__init__.py +7 -2
  161. mindspore/ops/operations/_grad_ops.py +2 -76
  162. mindspore/ops/operations/_infer_ops.py +1 -1
  163. mindspore/ops/operations/_inner_ops.py +71 -94
  164. mindspore/ops/operations/array_ops.py +14 -146
  165. mindspore/ops/operations/comm_ops.py +63 -53
  166. mindspore/ops/operations/custom_ops.py +83 -19
  167. mindspore/ops/operations/debug_ops.py +42 -10
  168. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  169. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  170. mindspore/ops/operations/math_ops.py +12 -223
  171. mindspore/ops/operations/nn_ops.py +20 -114
  172. mindspore/ops/operations/other_ops.py +7 -4
  173. mindspore/ops/operations/random_ops.py +46 -1
  174. mindspore/ops/primitive.py +18 -6
  175. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  176. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  177. mindspore/ops_generate/gen_constants.py +36 -0
  178. mindspore/ops_generate/gen_ops.py +67 -52
  179. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  180. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  181. mindspore/ops_generate/op_proto.py +10 -3
  182. mindspore/ops_generate/pyboost_utils.py +14 -1
  183. mindspore/ops_generate/template.py +43 -21
  184. mindspore/parallel/__init__.py +3 -1
  185. mindspore/parallel/_auto_parallel_context.py +31 -9
  186. mindspore/parallel/_cell_wrapper.py +85 -0
  187. mindspore/parallel/_parallel_serialization.py +47 -19
  188. mindspore/parallel/_tensor.py +127 -13
  189. mindspore/parallel/_utils.py +53 -22
  190. mindspore/parallel/algo_parameter_config.py +5 -5
  191. mindspore/parallel/checkpoint_transform.py +46 -39
  192. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  193. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  194. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  195. mindspore/parallel/parameter_broadcast.py +3 -4
  196. mindspore/parallel/shard.py +162 -31
  197. mindspore/parallel/transform_safetensors.py +1146 -0
  198. mindspore/profiler/__init__.py +2 -1
  199. mindspore/profiler/common/constant.py +29 -0
  200. mindspore/profiler/common/registry.py +47 -0
  201. mindspore/profiler/common/util.py +28 -0
  202. mindspore/profiler/dynamic_profiler.py +694 -0
  203. mindspore/profiler/envprofiling.py +17 -19
  204. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  205. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  206. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  207. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  208. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  209. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  210. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  211. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  212. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  213. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  214. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  215. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  216. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  217. mindspore/profiler/parser/framework_parser.py +1 -391
  218. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  219. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  220. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  221. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  222. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  223. mindspore/profiler/parser/profiler_info.py +78 -6
  224. mindspore/profiler/profiler.py +153 -0
  225. mindspore/profiler/profiling.py +285 -413
  226. mindspore/rewrite/__init__.py +1 -2
  227. mindspore/rewrite/common/namespace.py +4 -4
  228. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  229. mindspore/run_check/_check_version.py +39 -104
  230. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  231. mindspore/swresample-4.dll +0 -0
  232. mindspore/swscale-6.dll +0 -0
  233. mindspore/tinyxml2.dll +0 -0
  234. mindspore/train/__init__.py +4 -3
  235. mindspore/train/_utils.py +105 -19
  236. mindspore/train/amp.py +171 -53
  237. mindspore/train/callback/__init__.py +2 -2
  238. mindspore/train/callback/_callback.py +4 -4
  239. mindspore/train/callback/_checkpoint.py +97 -31
  240. mindspore/train/callback/_cluster_monitor.py +1 -1
  241. mindspore/train/callback/_flops_collector.py +1 -0
  242. mindspore/train/callback/_loss_monitor.py +3 -3
  243. mindspore/train/callback/_on_request_exit.py +145 -31
  244. mindspore/train/callback/_summary_collector.py +5 -5
  245. mindspore/train/callback/_tft_register.py +375 -0
  246. mindspore/train/dataset_helper.py +15 -3
  247. mindspore/train/metrics/metric.py +3 -3
  248. mindspore/train/metrics/roc.py +4 -4
  249. mindspore/train/mind_ir_pb2.py +44 -39
  250. mindspore/train/model.py +154 -58
  251. mindspore/train/serialization.py +342 -128
  252. mindspore/turbojpeg.dll +0 -0
  253. mindspore/utils/__init__.py +21 -0
  254. mindspore/utils/utils.py +60 -0
  255. mindspore/version.py +1 -1
  256. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  257. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
  258. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
  259. mindspore/include/c_api/ms/abstract.h +0 -67
  260. mindspore/include/c_api/ms/attribute.h +0 -197
  261. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  262. mindspore/include/c_api/ms/base/macros.h +0 -32
  263. mindspore/include/c_api/ms/base/status.h +0 -33
  264. mindspore/include/c_api/ms/base/types.h +0 -283
  265. mindspore/include/c_api/ms/context.h +0 -102
  266. mindspore/include/c_api/ms/graph.h +0 -160
  267. mindspore/include/c_api/ms/node.h +0 -606
  268. mindspore/include/c_api/ms/tensor.h +0 -161
  269. mindspore/include/c_api/ms/value.h +0 -84
  270. mindspore/mindspore_shared_lib.dll +0 -0
  271. mindspore/nn/extend/basic.py +0 -140
  272. mindspore/nn/extend/embedding.py +0 -143
  273. mindspore/nn/extend/layer/normalization.py +0 -109
  274. mindspore/nn/extend/pooling.py +0 -117
  275. mindspore/nn/layer/embedding_service.py +0 -531
  276. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  277. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  278. mindspore/ops/extend/__init__.py +0 -53
  279. mindspore/ops/extend/array_func.py +0 -218
  280. mindspore/ops/extend/math_func.py +0 -76
  281. mindspore/ops/extend/nn_func.py +0 -308
  282. mindspore/ops/silent_check.py +0 -162
  283. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  284. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  285. mindspore/train/callback/_mindio_ttp.py +0 -443
  286. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  287. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
@@ -28,6 +28,8 @@ import subprocess
28
28
  import numpy as np
29
29
  import mindspore as ms
30
30
  from mindspore._c_expression import Oplib, typing
31
+ from mindspore._c_expression import pyboost_custom_ext
32
+ from mindspore.common._stub_tensor import _convert_stub
31
33
  from mindspore import context
32
34
  from mindspore.common import Tensor
33
35
  from mindspore.common import dtype as mstype
@@ -156,6 +158,55 @@ def _compile_aot(file):
156
158
  return func_path
157
159
 
158
160
 
161
+ class _CustomExt(ops.PrimitiveWithInfer):
162
+ """
163
+ `Custom` primitive is used for PyBoost.
164
+ """
165
+
166
+ def __init__(self, func, out_shape=None, out_dtype=None, bprop=None):
167
+ super().__init__("CustomExt")
168
+ self.func = func
169
+ self.out_shape = out_shape
170
+ self.out_dtype = out_dtype
171
+ self.bprop = bprop
172
+
173
+ def __infer__(self, *args):
174
+ if callable(self.out_shape):
175
+ infer_shape = self.out_shape(*(x["shape"] for x in args))
176
+ else:
177
+ infer_shape = self.out_shape
178
+
179
+ if callable(self.out_dtype):
180
+ infer_dtype = self.out_dtype(*(x["dtype"] for x in args))
181
+ else:
182
+ infer_dtype = self.out_dtype
183
+
184
+ infer_value = None
185
+ if infer_shape is None:
186
+ logger.warning("'out_shape' is None. Add a placeholder instead. "
187
+ "A CPP version of infer shape function is required "
188
+ "in this case.")
189
+ infer_shape = (1,)
190
+ # after all automatic infer information fulfillment, throw error if infer_shape/infer_dtype is still None
191
+ if not isinstance(infer_shape, (tuple, list)):
192
+ raise TypeError("'out_shape' must be one of [tuple, list, function], but got {}".format(type(infer_shape)))
193
+
194
+ if not isinstance(infer_dtype, (typing.Type, tuple, list)):
195
+ raise TypeError("'out_dtype' must be one of [mindspore.dtype, tuple, list, function], but got {}"
196
+ .format(type(infer_dtype)))
197
+
198
+ out = {
199
+ "shape": infer_shape,
200
+ "dtype": infer_dtype,
201
+ "value": infer_value,
202
+ }
203
+ return out
204
+
205
+ def get_bprop(self):
206
+ """return back propagation function"""
207
+ return self.bprop
208
+
209
+
159
210
  class Custom(ops.PrimitiveWithInfer):
160
211
  r"""
161
212
  `Custom` primitive is used for user defined operators and is to enhance the expressive ability of built-in
@@ -164,7 +215,7 @@ class Custom(ops.PrimitiveWithInfer):
164
215
  function if needed. Then these `Custom` objects can be directly used in neural networks.
165
216
  Detailed description and introduction of user-defined operators, including correct writing of parameters,
166
217
  please refer to `Custom Operators Tutorial
167
- <https://www.mindspore.cn/tutorials/experts/en/master/operation/op_custom.html>`_ .
218
+ <https://www.mindspore.cn/docs/en/master/model_train/custom_program/op_custom.html>`_ .
168
219
 
169
220
  .. warning::
170
221
  - This is an experimental API that is subject to change.
@@ -174,7 +225,7 @@ class Custom(ops.PrimitiveWithInfer):
174
225
 
175
226
  - "hybrid": supports ["GPU", "CPU"].
176
227
  - "akg": supports ["GPU", "CPU"].
177
- - "aot": supports ["GPU", "CPU", "ASCEDN"].
228
+ - "aot": supports ["GPU", "CPU", "Ascend"].
178
229
  - "pyfunc": supports ["CPU"].
179
230
  - "julia": supports ["CPU"].
180
231
 
@@ -249,20 +300,18 @@ class Custom(ops.PrimitiveWithInfer):
249
300
  (ex. Custom(func="./reorganize.so:CustomReorganize", out_shape=[1], out_dtype=mstype.float32,
250
301
  "aot"))
251
302
 
252
- b) ASCEND platform
253
- Before using Custom operators on the ASCEND platform, users must first develop custom operators
254
- based on Ascend C and compile them. For operator development, you can refer to the tutorial on
255
- `Quick Start for End-to-End Operator Development
256
- <https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0022.html>`_,
257
- and for compiling custom operators, you can use the `Offline Compilation of Ascend C Custom Operators
258
- <https://www.mindspore.cn/tutorials/experts/en/master/operation/op_custom_ascendc.html>` tool.
259
- When passing the operator's name into the func parameter, taking AddCustom as an example for the
260
- name given in the custom operator implementation, there are several ways to use it:
261
-
262
- - Usin TBE: func="AddCustom"
263
- - Using AclNN: func="aclnnAddCustom"
264
- - Inferring the shape of the operator through C++ derivation: func="infer_shape.cc:aclnnAddCustom",
265
- where infer_shape.cc is the shape derivation implemented in C++.
303
+ b) Ascend platform.
304
+ Before using Custom operators on the Ascend platform, users must first develop custom operators
305
+ based on Ascend C and compile them. The complete development and usage process can refer to the
306
+ tutorial `AOT-Type Custom Operators(Ascend) <https://www.mindspore.cn/docs/en/master/model_train/custom_program/operation/op_custom_ascendc.html>`_.
307
+ By passing the name of the operator through the input parameter `func`, there are two usage methods
308
+ based on the implementation of the infer shape function:
309
+
310
+ - Python infer: If the operator's infer shape is implemented in Python, that is, the infer shape
311
+ function is passed through the `out_shape` parameter, specify `func="CustomName"` .
312
+ - C++ infer: If the operator's infer shape is implemented through C++, then pass the path of the
313
+ infer shape implementation file in `func` and separate the operator name with `:`,
314
+ for example: `func="add_custom_infer.cc:AddCustom"` .
266
315
 
267
316
  2. for "julia":
268
317
 
@@ -338,7 +387,7 @@ class Custom(ops.PrimitiveWithInfer):
338
387
  or the attributes of `func` differs in different targets.
339
388
 
340
389
  Supported Platforms:
341
- ``GPU`` ``CPU`` ``ASCEND``
390
+ ``Ascend`` ``GPU`` ``CPU``
342
391
 
343
392
  Examples:
344
393
  >>> import numpy as np
@@ -457,6 +506,12 @@ class Custom(ops.PrimitiveWithInfer):
457
506
 
458
507
  self.add_prim_attr("func_type", self.func_type)
459
508
  self._update_attr()
509
+ self.enable_pyboost = False
510
+ self.custom_pyboost = _CustomExt(self.func, self.out_shape, self.out_dtype, self.bprop)
511
+ if context.get_context("device_target") == "Ascend" and self.func_type == "aot":
512
+ self.enable_pyboost = True
513
+ for key, value in super().get_attr_dict().items():
514
+ self.custom_pyboost.add_prim_attr(key, value)
460
515
 
461
516
  def __infer__(self, *args):
462
517
  if callable(self.out_shape):
@@ -559,7 +614,7 @@ class Custom(ops.PrimitiveWithInfer):
559
614
  raise TypeError(
560
615
  "{}, 'func' should be like 'file_name:func_name', but got {}".format(
561
616
  self.log_prefix, self.func))
562
- file_path = os.path.abspath(file_name_list[0])
617
+ file_path = os.path.realpath(file_name_list[0])
563
618
  if os.environ.get('MS_CUSTOM_AOT_WHITE_LIST') is None:
564
619
  if Custom.custom_aot_warning:
565
620
  logger.info("{}, no white list is set and it might cause problems. "
@@ -567,7 +622,7 @@ class Custom(ops.PrimitiveWithInfer):
567
622
  .format(self.log_prefix))
568
623
  Custom.custom_aot_warning = False
569
624
  else:
570
- legal_path = os.path.abspath(os.environ.get('MS_CUSTOM_AOT_WHITE_LIST'))
625
+ legal_path = os.path.realpath(os.environ.get('MS_CUSTOM_AOT_WHITE_LIST'))
571
626
  if legal_path not in file_path:
572
627
  raise TypeError(
573
628
  "{}, the legal path for the file is {}, but the file is {}".format(
@@ -1063,3 +1118,12 @@ class Custom(ops.PrimitiveWithInfer):
1063
1118
  infer_value = Tensor(fake_output) if enable_infer_value else None
1064
1119
 
1065
1120
  return infer_shape, infer_dtype, infer_value
1121
+
1122
+ def __call__(self, *args):
1123
+ if self.enable_pyboost:
1124
+ return _convert_stub(pyboost_custom_ext(self.custom_pyboost, [args]))
1125
+ should_elim, output = self.check_elim(*args)
1126
+ if should_elim:
1127
+ return output
1128
+ # pylint: disable=protected-access
1129
+ return ops.primitive._run_op(self, self.name, args)
@@ -15,17 +15,17 @@
15
15
  """debug_ops"""
16
16
  import os
17
17
  import stat
18
- from types import FunctionType, MethodType
19
18
 
20
19
  import numpy as np
21
20
  from mindspore import log as logger
22
- from mindspore._c_expression import security
21
+ from mindspore._c_expression import security, HookType
23
22
  from mindspore._c_expression import Tensor as Tensor_
24
23
  from mindspore import _checkparam as validator
25
24
  from mindspore.common import dtype as mstype
26
25
  from mindspore.common.parameter import Parameter
27
26
  from mindspore.common.tensor import Tensor
28
27
  from mindspore.ops.primitive import prim_attr_register, Primitive, PrimitiveWithInfer
28
+ from mindspore._checkparam import check_hook_fn
29
29
 
30
30
 
31
31
  SUMMARY_TENSOR_CACHE = []
@@ -64,6 +64,8 @@ class ScalarSummary(Primitive):
64
64
  which specify the directory of the summary file. The summary file can
65
65
  be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
66
66
  mindinsight/docs/en/master/index.html>`_ for details.
67
+ In Ascend platform with graph mode, can set environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
68
+ to solve operator execution failure when calling this operator intensively.
67
69
 
68
70
  Inputs:
69
71
  - **name** (str) - The name of the input variable, it must not be an empty string.
@@ -122,6 +124,8 @@ class ImageSummary(Primitive):
122
124
  SummaryRecord or SummaryCollector, which specify the directory of the summary file. The summary file can
123
125
  be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
124
126
  mindinsight/docs/en/master/index.html>`_ for details.
127
+ In Ascend platform with graph mode, can set environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
128
+ to solve operator execution failure when calling this operator intensively.
125
129
 
126
130
  Inputs:
127
131
  - **name** (str) - The name of the input variable, it must not be an empty string.
@@ -173,6 +177,8 @@ class TensorSummary(Primitive):
173
177
  or SummaryCollector, which specify the directory of the summary file. The summary file can
174
178
  be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
175
179
  mindinsight/docs/en/master/index.html>`_ for details.
180
+ In Ascend platform with graph mode, can set environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
181
+ to solve operator execution failure when calling this operator intensively.
176
182
 
177
183
  Inputs:
178
184
  - **name** (str) - The name of the input variable.
@@ -228,9 +234,6 @@ class TensorDump(Primitive):
228
234
  """
229
235
  Save the Tensor as an npy file in numpy format.
230
236
 
231
- The file name will automatically have a prefix added based on the execution order. For example, if `file` is `a`,
232
- the first saved file will be named `0_a.npy`, and the second one will be named `1_a.npy`, and so on.
233
-
234
237
  .. warning::
235
238
  - If a large amount of data is stored within a short period, it may lead to memory overflow on the device side.
236
239
  Consider slicing the data to reduce the data scale.
@@ -238,6 +241,34 @@ class TensorDump(Primitive):
238
241
  too quickly, data loss may occur. You need to actively control the destruction time of the main process,
239
242
  such as using sleep.
240
243
 
244
+ Args:
245
+ input_output (str, optional): Used to control Tensordump behavior.
246
+ Available value is one of ['in', 'out', 'all']. Default value is ``out``.
247
+
248
+ In case of OpA --> RedistributionOps --> OpB,
249
+ The dump data of OpA's output is not equal to OpB's input (Due to the redistribution operators).
250
+ So the parameter input_output is to handle this situation.
251
+
252
+ Assuming OpA's output is used as both Tensordump's input parameter and OpB's input parameter.
253
+ Different requirements of saving dump data can be achieved by configuring parameter input_output:
254
+
255
+ - If the input_output is 'out', the dump data contains only OpA's output slice.
256
+ - If the input_output is 'all', the dump data contains both OpA's output slice and OpB's input slice.
257
+ - If the input_output is 'in', the dump data contains only OpB's input slice.
258
+
259
+ For input_output is 'all' or 'in', the input slice npy file format is:
260
+ id_fileName_cNodeID_dumpMode_rankID.npy.
261
+
262
+ For input_output is 'out' or 'all' the output slice npy file format is:
263
+ id_fileName.npy.
264
+
265
+ - id: An auto increment ID.
266
+ - fileName: Value of the parameter file
267
+ (if parameter file_name is a user-specified path, the value of fileName is the last level of the path).
268
+ - cNodeID: The node ID of the Tensordump node in the step_parallel_end.ir file.
269
+ - dumpMode: Value of the parameter input_output.
270
+ - rankID: Logical device id.
271
+
241
272
  Inputs:
242
273
  - **file** (str) - The path of the file to be saved.
243
274
  - **input_x** (Tensor) - Input Tensor of any dimension.
@@ -280,7 +311,7 @@ class TensorDump(Primitive):
280
311
  [6. 7. 8. 9.]]
281
312
  """
282
313
  @prim_attr_register
283
- def __init__(self):
314
+ def __init__(self, input_output='out'):
284
315
  """Initialize TensorDump."""
285
316
  if security.enable_security():
286
317
  raise ValueError('The TensorDump is not supported, please without `-s on` and recompile source.')
@@ -314,6 +345,8 @@ class HistogramSummary(Primitive):
314
345
  It must be used with SummaryRecord or SummaryCollector, which specify the directory of the summary file.
315
346
  The summary file can be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
316
347
  mindinsight/docs/en/master/index.html>`_ for details.
348
+ In Ascend platform with graph mode, can set environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
349
+ to solve operator execution failure when calling this operator intensively.
317
350
 
318
351
  Inputs:
319
352
  - **name** (str) - The name of the input variable.
@@ -499,16 +532,15 @@ class HookBackward(PrimitiveWithInfer):
499
532
  def __init__(self, hook_fn, cell_id=""):
500
533
  """Initialize HookBackward."""
501
534
  super(HookBackward, self).__init__(self.__class__.__name__)
502
- if not isinstance(hook_fn, (FunctionType, MethodType)):
503
- raise TypeError(f"For '{self.name}', the type of 'hook_fn' must be python function, "
504
- f"but got {type(hook_fn)}.")
535
+ if not check_hook_fn("HookBackward", hook_fn):
536
+ return
505
537
  if cell_id != "":
506
538
  logger.warning(f"The args 'cell_id' of HookBackward will be removed in a future version. If the value of "
507
539
  f"'cell_id' is set, the hook function will not work.")
508
540
  self.add_prim_attr("cell_id", cell_id)
509
541
  self.init_attrs["cell_id"] = cell_id
510
542
  self.cell_id = cell_id
511
- self.add_backward_hook_fn(hook_fn)
543
+ self.set_hook_fn(hook_fn, HookType.HookBackward)
512
544
 
513
545
  def infer_shape(self, *inputs_shape):
514
546
  if len(inputs_shape) == 1:
@@ -23,6 +23,7 @@ from mindspore.common import dtype as mstype
23
23
  from mindspore import _checkparam as validator
24
24
  from mindspore.common._decorator import deprecated
25
25
  from mindspore.ops.primitive import prim_attr_register, Primitive
26
+ from mindspore import log as logger
26
27
 
27
28
 
28
29
  class ScalarCast(Primitive):
@@ -59,3 +60,14 @@ class ScalarCast(Primitive):
59
60
  value = np.cast[np_dtype.lower()](input_x)
60
61
  value = value.item()
61
62
  return value
63
+
64
+
65
+ class TensorReport(Primitive):
66
+ @prim_attr_register
67
+ def __init__(self):
68
+ """Initialize TensorReport"""
69
+ self.add_prim_attr("side_effect_io", True)
70
+ self.add_prim_attr("channel_name", "ms_tensor_report")
71
+
72
+ def __call__(self, file, input_x):
73
+ logger.warning("TensorReport doesn't support pynative mode.")
@@ -18,6 +18,7 @@ from __future__ import division
18
18
 
19
19
  import numbers
20
20
  import math
21
+ import types
21
22
  import numpy as np
22
23
  from mindspore.ops import signature as sig
23
24
  from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register, PrimitiveWithInfer
@@ -937,6 +938,10 @@ class Tile(Primitive):
937
938
 
938
939
  Refer to :func:`mindspore.ops.tile` for more details.
939
940
 
941
+ Note:
942
+ On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
943
+ where more than 4 dimensions are repeated simultaneously.
944
+
940
945
  Inputs:
941
946
  - **input** (Tensor) - The tensor whose elements need to be repeated. Set the shape of input tensor as
942
947
  :math:`(x_1, x_2, ..., x_S)` .
@@ -1025,6 +1030,10 @@ def tile(input, dims):
1025
1030
  output tensor has `input.shape[i] * dims[i]` elements, and the values of `input`
1026
1031
  are replicated `dims[i]` times along the i'th dimension.
1027
1032
 
1033
+ Note:
1034
+ On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
1035
+ where more than 4 dimensions are repeated simultaneously.
1036
+
1028
1037
  Args:
1029
1038
  input (Tensor): The tensor whose elements need to be repeated. Set the shape of input tensor as
1030
1039
  :math:`(x_1, x_2, ..., x_S)` .
@@ -1127,16 +1136,16 @@ class Cast(Primitive):
1127
1136
  taken into account. As long as the real part is non-zero, it returns True; otherwise, it returns False.
1128
1137
 
1129
1138
  Inputs:
1130
- - **input_x** (Union[Tensor, Number]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1139
+ - **input** (Union[Tensor, Number]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1131
1140
  The tensor to be cast.
1132
- - **type** (dtype.Number) - The valid data type of the output tensor. Only constant value is allowed.
1141
+ - **dtype** (dtype.Number) - The valid data type of the output tensor. Only constant value is allowed.
1133
1142
 
1134
1143
  Outputs:
1135
- Tensor, the shape of tensor is the same as `input_x`, :math:`(x_1, x_2, ..., x_R)`.
1144
+ Tensor, the shape of tensor is the same as `input`, :math:`(x_1, x_2, ..., x_R)`.
1136
1145
 
1137
1146
  Raises:
1138
- TypeError: If `input_x` is neither Tensor nor Number.
1139
- TypeError: If `type` is not a Number.
1147
+ TypeError: If `input` is neither Tensor nor Number.
1148
+ TypeError: If `dtype` is not a Number.
1140
1149
 
1141
1150
  Supported Platforms:
1142
1151
  ``Ascend`` ``GPU`` ``CPU``
@@ -1146,10 +1155,10 @@ class Cast(Primitive):
1146
1155
  >>> import numpy as np
1147
1156
  >>> from mindspore import Tensor, ops
1148
1157
  >>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
1149
- >>> input_x = Tensor(input_np)
1150
- >>> type_dst = mindspore.int32
1158
+ >>> input = Tensor(input_np)
1159
+ >>> dtype = mindspore.int32
1151
1160
  >>> cast = ops.Cast()
1152
- >>> output = cast(input_x, type_dst)
1161
+ >>> output = cast(input, dtype)
1153
1162
  >>> print(output.dtype)
1154
1163
  Int32
1155
1164
  >>> print(output.shape)
@@ -1162,17 +1171,15 @@ class Cast(Primitive):
1162
1171
  self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
1163
1172
 
1164
1173
  def check_elim(self, x, dtype):
1165
- if isinstance(x, (Tensor, numbers.Number, Parameter)):
1166
- if isinstance(x, Parameter):
1167
- data = x.data
1168
- if data.dtype == dtype:
1169
- return (True, x)
1170
- if isinstance(x, Tensor) and x.dtype == dtype:
1171
- x = Tensor(x)
1172
- x.set_cast_dtype()
1174
+ if isinstance(x, Parameter):
1175
+ data = x.data
1176
+ if data.dtype == dtype:
1173
1177
  return (True, x)
1174
- if isinstance(x, numbers.Number):
1175
- return (True, Tensor(x, dtype=dtype))
1178
+ if isinstance(x, Tensor) and x.dtype == dtype:
1179
+ x.set_cast_dtype()
1180
+ return (True, x)
1181
+ if isinstance(x, numbers.Number):
1182
+ return (True, Tensor(x, dtype=dtype))
1176
1183
  return (False, None)
1177
1184
 
1178
1185
  def __call__(self, input_x, dtype):
@@ -1187,7 +1194,7 @@ def to_sequence(val):
1187
1194
  to_sequence
1188
1195
  """
1189
1196
  if isinstance(val, (tuple, list)):
1190
- return val
1197
+ return tuple(val)
1191
1198
  return (val,)
1192
1199
 
1193
1200
 
@@ -1891,7 +1898,7 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
1891
1898
  H2 -- Hidden size of key and value, which equals to N2 * D.
1892
1899
 
1893
1900
  .. warning::
1894
- This is an experimental API that is subject to change or deletion. Only support on Atlas training series.
1901
+ This is an experimental API that is subject to change or deletion. Only support on Atlas A2 training series.
1895
1902
 
1896
1903
  Args:
1897
1904
  query (Tensor[float16, bfloat16]): The query tensor. Input tensor of shape :math:`(B, S1, H1)`,
@@ -2014,3 +2021,249 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
2014
2021
  inner_precise, input_layout, sparse_mode)
2015
2022
  return rank_op(query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen,
2016
2023
  actual_seq_kvlen)[3]
2024
+
2025
+
2026
class WhileLoop(Primitive):
    """
    Provide a useful op for reducing the compilation time of while loops.
    The execution logic of the WhileLoop operator can be roughly represented by the following code:

    .. code-block:: python

        def WhileLoop(cond_func, loop_func, init_val):
            while cond_func(init_val):
                init_val = loop_func(init_val)
            return init_val

    The current WhileLoop operator has the following syntactic limitations:

    - Using a side-effect function as `loop_func` is currently not supported,
      such as operations that modify parameters, global variables, etc.
    - A return value of `loop_func` whose type or shape differs from that of
      `init_val` is currently not supported.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Inputs:
        - **cond_func** (Function) - The condition function.
        - **loop_func** (Function) - The loop function. It takes one argument and
          returns a value of the same type as that argument.
        - **init_val** (Union[Tensor, number, str, bool, list, tuple, dict]) - The initial value.

    Outputs:
        Union[Tensor, number, str, bool, list, tuple, dict], the final result of the while loop,
        which has the same type and shape as the input `init_val` .

    Raises:
        TypeError: If `cond_func` is not a function.
        TypeError: If `loop_func` is not a function.
        ValueError: If `loop_func` cannot take `init_val` as input or has a different
            output type or shape from `init_val` .

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> def loop_while_fun(init_val):
        ...     val = init_val
        ...     val = val + 1
        ...     return val
        ...
        >>> init_state = 10
        >>> while_loop = ops.WhileLoop()
        >>> result = while_loop(lambda x : x < 100, loop_while_fun, init_state)
        >>> print(result)
        100
    """

    @prim_attr_register
    def __init__(self):
        """Initialize WhileLoop."""

    def __call__(self, cond_func, loop_func, init_val):
        validator.check_value_type("cond_func", cond_func,
                                   [types.FunctionType, types.MethodType], "WhileLoop")
        validator.check_value_type("loop_func", loop_func,
                                   [types.FunctionType, types.MethodType], "WhileLoop")
        val = init_val
        try:
            while cond_func(val):
                val = loop_func(val)
        except Exception as e:
            # Note: exceptions raised by `cond_func` are reported under the same message.
            # Implicit literal concatenation keeps the message free of the source
            # indentation that a backslash continuation inside the literal would embed.
            raise ValueError("Invalid loop_func, please check input arguments and "
                             "return value, error info: {}".format(e)) from e
        return val
2098
+
2099
+
2100
class Scan(Primitive):
    """
    Scan a function over an array while the processing of the current element
    depends on the execution result of the previous element.
    The execution logic of the Scan operator can be roughly represented by the following code:

    .. code-block:: python

        def Scan(loop_func, init, xs, length=None):
            if xs is None:
                xs = [None] * length
            carry = init
            ys = []
            for x in xs:
                carry, y = loop_func(carry, x)
                ys.append(y)
            return carry, ys

    The current Scan operator has the following syntactic limitations:

    - Using a side-effect function as `loop_func` is currently not supported,
      such as operations that modify parameters, global variables, etc.
    - A first element of the return value of `loop_func` whose type or shape
      differs from that of `init` is currently not supported.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Inputs:
        - **loop_func** (Function) - The loop function. It takes the current carry and one
          element of `xs`, and returns a tuple ``(new_carry, y)``.
        - **init** (Union[Tensor, number, str, bool, list, tuple, dict]) - An initial loop carry value.
        - **xs** (Union[tuple, list, None]) - The value over which to scan.
        - **length** (Union[int, None], optional) - The size of `xs`. In eager execution it is only
          read when `xs` is ``None``, in which case `loop_func` receives ``None`` as its second
          argument `length` times. Default: ``None`` .
        - **unroll** (bool, optional) - The flag for whether to perform loop unrolling in compile process.
          Default: ``True`` .

    Outputs:
        Tuple(Union[Tensor, number, str, bool, list, tuple, dict], list). Output of scan loop,
        a tuple with two elements, the first element has the same type and shape as the `init` argument,
        and the second is a list containing the results of each loop.

    Raises:
        TypeError: If `loop_func` is not a function.
        TypeError: If `xs` is not a tuple, a list or None.
        TypeError: If `length` is not an int or None.
        TypeError: If `unroll` is not a bool.
        ValueError: If `loop_func` cannot take `init` and element of `xs` as inputs.
        ValueError: If the return value of `loop_func` is not a tuple with two elements,
            or the first element of the tuple has a different type or shape from `init` .

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> def cumsum(res, el):
        ...     res = res + el
        ...     return res, res
        ...
        >>> a = [1, 2, 3, 4]
        >>> result_init = 0
        >>> scan_op = ops.Scan()
        >>> result = scan_op(cumsum, result_init, a)
        >>> print(result == (10, [1, 3, 6, 10]))
        True
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Scan."""

    def __call__(self, loop_func, init, xs, length=None, unroll=True):
        validator.check_value_type("loop_func", loop_func,
                                   [types.FunctionType, types.MethodType], "Scan")
        # Use NoneType, not bare `None`: `None` is not a type, so an isinstance()-based
        # check fails on it precisely in the documented `xs=None` case.
        validator.check_value_type("xs", xs, [list, tuple, type(None)], "Scan")
        # `unroll` only affects graph compilation and is otherwise unused on this eager
        # path; it is validated to honour the documented TypeError.
        validator.check_value_type("unroll", unroll, [bool], "Scan")
        if xs is None:
            validator.check_value_type("length", length, [int], "Scan")
            xs = [None] * length
        carry = init
        ys = []
        try:
            # An empty `xs` yields (init, []) because the loop body never runs.
            for x in xs:
                carry, y = loop_func(carry, x)
                ys.append(y)
        except Exception as e:
            # Implicit literal concatenation avoids embedding source indentation in the message.
            raise ValueError("Invalid loop_func, please check input arguments and "
                             "return value, error info: {}".format(e)) from e
        return carry, ys
2194
+
2195
+
2196
class ForiLoop(Primitive):
    """
    Provide a useful op for looping from `lower` to `upper`.
    The execution logic of the ForiLoop operator can be roughly represented by the following code:

    .. code-block:: python

        def ForiLoop(lower, upper, loop_func, init_val):
            for i in range(lower, upper):
                init_val = loop_func(i, init_val)
            return init_val

    The current ForiLoop operator has the following syntactic limitations:

    - Using a side-effect function as `loop_func` is currently not supported,
      such as operations that modify parameters, global variables, etc.
    - A return value of `loop_func` whose type or shape differs from that of
      `init_val` is currently not supported.
    - Negative numbers or custom increments are currently not supported.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Inputs:
        - **lower** (Union[int, Tensor]) - The start index of loop.
        - **upper** (Union[int, Tensor]) - The end index of loop.
        - **loop_func** (Function) - The loop function, takes two arguments: the loop
          index and the current value.
        - **init_val** (Union[Tensor, number, str, bool, list, tuple, dict]) - The init value.
        - **unroll** (bool, optional) - The flag for whether to unroll in compile process,
          only valid when the number of loop iterations is determined. Default: ``True`` .

    Outputs:
        Union[Tensor, number, str, bool, list, tuple, dict], the final result of the loop,
        which has the same type and shape as the input `init_val` .

    Raises:
        TypeError: If `lower` is not an int or a Tensor.
        TypeError: If `upper` is not an int or a Tensor.
        TypeError: If `loop_func` is not a function.
        ValueError: If `loop_func` cannot take index and `init_val` as arguments or if the type
            or shape of the output it produces differs from that of `init_val` .

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> def cumsum(index, res):
        ...     return index + res
        ...
        >>> result_init = 0
        >>> fori_loop = ops.ForiLoop()
        >>> result = fori_loop(0, 4, cumsum, result_init)
        >>> print(result)
        6
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ForiLoop."""

    def __call__(self, lower, upper, loop_func, init_val, unroll=True):
        validator.check_value_type("lower", lower, [int, Tensor], "ForiLoop")
        validator.check_value_type("upper", upper, [int, Tensor], "ForiLoop")
        validator.check_value_type("loop_func", loop_func,
                                   [types.FunctionType, types.MethodType], "ForiLoop")
        # `unroll` only affects graph compilation; it is unused on this eager path.
        val = init_val
        try:
            # Tensor bounds go straight into range(), so they must be usable as integer
            # indices — NOTE(review): relies on Tensor supporting __index__; confirm.
            for i in range(lower, upper):
                val = loop_func(i, val)
        except Exception as e:
            # Implicit literal concatenation avoids embedding source indentation in the message.
            raise ValueError("Invalid loop_func, please check input arguments and "
                             "return value, error info: {}".format(e)) from e
        return val