mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/ops/operations/custom_ops.py:

@@ -28,6 +28,8 @@ import subprocess
 import numpy as np
 import mindspore as ms
 from mindspore._c_expression import Oplib, typing
+from mindspore._c_expression import pyboost_custom_ext
+from mindspore.common._stub_tensor import _convert_stub
 from mindspore import context
 from mindspore.common import Tensor
 from mindspore.common import dtype as mstype
@@ -156,6 +158,55 @@ def _compile_aot(file):
     return func_path


+class _CustomExt(ops.PrimitiveWithInfer):
+    """
+    `Custom` primitive is used for PyBoost.
+    """
+
+    def __init__(self, func, out_shape=None, out_dtype=None, bprop=None):
+        super().__init__("CustomExt")
+        self.func = func
+        self.out_shape = out_shape
+        self.out_dtype = out_dtype
+        self.bprop = bprop
+
+    def __infer__(self, *args):
+        if callable(self.out_shape):
+            infer_shape = self.out_shape(*(x["shape"] for x in args))
+        else:
+            infer_shape = self.out_shape
+
+        if callable(self.out_dtype):
+            infer_dtype = self.out_dtype(*(x["dtype"] for x in args))
+        else:
+            infer_dtype = self.out_dtype
+
+        infer_value = None
+        if infer_shape is None:
+            logger.warning("'out_shape' is None. Add a placeholder instead. "
+                           "A CPP version of infer shape function is required "
+                           "in this case.")
+            infer_shape = (1,)
+        # after all automatic infer information fulfillment, throw error if infer_shape/infer_dtype is still None
+        if not isinstance(infer_shape, (tuple, list)):
+            raise TypeError("'out_shape' must be one of [tuple, list, function], but got {}".format(type(infer_shape)))
+
+        if not isinstance(infer_dtype, (typing.Type, tuple, list)):
+            raise TypeError("'out_dtype' must be one of [mindspore.dtype, tuple, list, function], but got {}"
+                            .format(type(infer_dtype)))
+
+        out = {
+            "shape": infer_shape,
+            "dtype": infer_dtype,
+            "value": infer_value,
+        }
+        return out
+
+    def get_bprop(self):
+        """return back propagation function"""
+        return self.bprop
+
+
 class Custom(ops.PrimitiveWithInfer):
     r"""
     `Custom` primitive is used for user defined operators and is to enhance the expressive ability of built-in
@@ -164,7 +215,7 @@ class Custom(ops.PrimitiveWithInfer):
     function if needed. Then these `Custom` objects can be directly used in neural networks.
     Detailed description and introduction of user-defined operators, including correct writing of parameters,
     please refer to `Custom Operators Tutorial
-    <https://www.mindspore.cn/
+    <https://www.mindspore.cn/docs/en/master/model_train/custom_program/op_custom.html>`_ .

     .. warning::
         - This is an experimental API that is subject to change.
@@ -174,7 +225,7 @@ class Custom(ops.PrimitiveWithInfer):

         - "hybrid": supports ["GPU", "CPU"].
         - "akg": supports ["GPU", "CPU"].
-        - "aot": supports ["GPU", "CPU", "
+        - "aot": supports ["GPU", "CPU", "Ascend"].
         - "pyfunc": supports ["CPU"].
         - "julia": supports ["CPU"].

@@ -249,20 +300,18 @@ class Custom(ops.PrimitiveWithInfer):
        (ex. Custom(func="./reorganize.so:CustomReorganize", out_shape=[1], out_dtype=mstype.float32,
        "aot"))

-        b)
-        Before using Custom operators on the
-        based on Ascend C and compile them.
-        `
-
-
-
-
-
-
-
-
-        - Inferring the shape of the operator through C++ derivation: func="infer_shape.cc:aclnnAddCustom",
-          where infer_shape.cc is the shape derivation implemented in C++.
+        b) Ascend platform.
+           Before using Custom operators on the Ascend platform, users must first develop custom operators
+           based on Ascend C and compile them. The complete development and usage process can refer to the
+           tutorial `AOT-Type Custom Operators(Ascend) <https://www.mindspore.cn/docs/en/master/model_train/custom_program/operation/op_custom_ascendc.html>`_.
+           By passing the name of the operator through the input parameter `func`, there are two usage methods
+           based on the implementation of the infer shape function:
+
+           - Python infer: If the operator's infer shape is implemented in Python, that is, the infer shape
+             function is passed through the `out_shape` parameter, specify `func="CustomName"` .
+           - C++ infer: If the operator's infer shape is implemented through C++, then pass the path of the
+             infer shape implementation file in `func` and separate the operator name with `:`,
+             for example: `func="add_custom_infer.cc:AddCustom"` .

     2. for "julia":

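A rough sketch of the two usage methods above, assuming an Ascend C kernel named "AddCustom" has already been built per the linked tutorial (the operator and file names follow the tutorial's running example, not this package):

from mindspore import ops

# a) Python infer: shape/dtype derivation supplied as Python callables.
add_py_infer = ops.Custom(
    func="AddCustom",
    out_shape=lambda x, y: x,   # output shape follows the first input
    out_dtype=lambda x, y: x,   # output dtype follows the first input
    func_type="aot")

# b) C++ infer: derivation compiled from the .cc file named before the colon.
add_cpp_infer = ops.Custom(func="add_custom_infer.cc:AddCustom", func_type="aot")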
@@ -338,7 +387,7 @@ class Custom(ops.PrimitiveWithInfer):
         or the attributes of `func` differs in different targets.

     Supported Platforms:
-        ``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> import numpy as np
@@ -457,6 +506,12 @@ class Custom(ops.PrimitiveWithInfer):

         self.add_prim_attr("func_type", self.func_type)
         self._update_attr()
+        self.enable_pyboost = False
+        self.custom_pyboost = _CustomExt(self.func, self.out_shape, self.out_dtype, self.bprop)
+        if context.get_context("device_target") == "Ascend" and self.func_type == "aot":
+            self.enable_pyboost = True
+            for key, value in super().get_attr_dict().items():
+                self.custom_pyboost.add_prim_attr(key, value)

     def __infer__(self, *args):
         if callable(self.out_shape):
@@ -559,7 +614,7 @@ class Custom(ops.PrimitiveWithInfer):
                 raise TypeError(
                     "{}, 'func' should be like 'file_name:func_name', but got {}".format(
                         self.log_prefix, self.func))
-            file_path = os.path.
+            file_path = os.path.realpath(file_name_list[0])
            if os.environ.get('MS_CUSTOM_AOT_WHITE_LIST') is None:
                if Custom.custom_aot_warning:
                    logger.info("{}, no white list is set and it might cause problems. "
@@ -567,7 +622,7 @@ class Custom(ops.PrimitiveWithInfer):
                                 .format(self.log_prefix))
                     Custom.custom_aot_warning = False
             else:
-                legal_path = os.path.
+                legal_path = os.path.realpath(os.environ.get('MS_CUSTOM_AOT_WHITE_LIST'))
                 if legal_path not in file_path:
                     raise TypeError(
                         "{}, the legal path for the file is {}, but the file is {}".format(
@@ -1063,3 +1118,12 @@ class Custom(ops.PrimitiveWithInfer):
             infer_value = Tensor(fake_output) if enable_infer_value else None

         return infer_shape, infer_dtype, infer_value
+
+    def __call__(self, *args):
+        if self.enable_pyboost:
+            return _convert_stub(pyboost_custom_ext(self.custom_pyboost, [args]))
+        should_elim, output = self.check_elim(*args)
+        if should_elim:
+            return output
+        # pylint: disable=protected-access
+        return ops.primitive._run_op(self, self.name, args)
mindspore/ops/operations/debug_ops.py:

@@ -15,17 +15,17 @@
 """debug_ops"""
 import os
 import stat
-from types import FunctionType, MethodType

 import numpy as np
 from mindspore import log as logger
-from mindspore._c_expression import security
+from mindspore._c_expression import security, HookType
 from mindspore._c_expression import Tensor as Tensor_
 from mindspore import _checkparam as validator
 from mindspore.common import dtype as mstype
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore.ops.primitive import prim_attr_register, Primitive, PrimitiveWithInfer
+from mindspore._checkparam import check_hook_fn


 SUMMARY_TENSOR_CACHE = []
@@ -64,6 +64,8 @@ class ScalarSummary(Primitive):
     which specify the directory of the summary file. The summary file can
     be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
     mindinsight/docs/en/master/index.html>`_ for details.
+    In Ascend platform with graph mode, can set environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
+    to solve operator execution failure when calling this operator intensively.

     Inputs:
         - **name** (str) - The name of the input variable, it must not be an empty string.
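The same MS_DUMP_SLICE_SIZE / MS_DUMP_WAIT_TIME note is added to ImageSummary, TensorSummary, and HistogramSummary in the hunks below. A sketch of setting the two variables before launch (both values are placeholders, not recommendations):

import os

# Applies to Ascend graph mode; set before MindSpore initializes.
os.environ["MS_DUMP_SLICE_SIZE"] = "1024"  # placeholder value
os.environ["MS_DUMP_WAIT_TIME"] = "10"     # placeholder value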
@@ -122,6 +124,8 @@ class ImageSummary(Primitive):
     SummaryRecord or SummaryCollector, which specify the directory of the summary file. The summary file can
     be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
     mindinsight/docs/en/master/index.html>`_ for details.
+    In Ascend platform with graph mode, can set environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
+    to solve operator execution failure when calling this operator intensively.

     Inputs:
         - **name** (str) - The name of the input variable, it must not be an empty string.
@@ -173,6 +177,8 @@ class TensorSummary(Primitive):
     or SummaryCollector, which specify the directory of the summary file. The summary file can
     be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
     mindinsight/docs/en/master/index.html>`_ for details.
+    In Ascend platform with graph mode, can set environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
+    to solve operator execution failure when calling this operator intensively.

     Inputs:
         - **name** (str) - The name of the input variable.
@@ -228,9 +234,6 @@ class TensorDump(Primitive):
     """
     Save the Tensor as an npy file in numpy format.

-    The file name will automatically have a prefix added based on the execution order. For example, if `file` is `a`,
-    the first saved file will be named `0_a.npy`, and the second one will be named `1_a.npy`, and so on.
-
     .. warning::
         - If a large amount of data is stored within a short period, it may lead to memory overflow on the device side.
           Consider slicing the data to reduce the data scale.
@@ -238,6 +241,34 @@ class TensorDump(Primitive):
           too quickly, data loss may occur. You need to actively control the destruction time of the main process,
           such as using sleep.

+    Args:
+        input_output (str, optional): Used to control Tensordump behavior.
+            Available value is one of ['in', 'out', 'all']. Default value is ``out``.
+
+            In case of OpA --> RedistributionOps --> OpB,
+            The dump data of OpA's output is not equal to OpB's input (Due to the redistribution operators).
+            So the parameter input_output is to handle this situation.
+
+            Assuming OpA's output is used as both Tensordump's input parameter and OpB's input parameter.
+            Different requirements of saving dump data can be achieved by configuring parameter input_output:
+
+            - If the input_output is 'out', the dump data contains only OpA's output slice.
+            - If the input_output is 'all', the dump data contains both OpA's output slice and OpB's input slice.
+            - If the input_output is 'in', the dump data contains only OpB's input slice.
+
+            For input_output is 'all' or 'in', the input slice npy file format is:
+            id_fileName_cNodeID_dumpMode_rankID.npy.
+
+            For input_output is 'out' or 'all' the output slice npy file format is:
+            id_fileName.npy.
+
+            - id: An auto increment ID.
+            - fileName: Value of the parameter file
+              (if parameter file_name is a user-specified path, the value of fileName is the last level of the path).
+            - cNodeID: The node ID of the Tensordump node in the step_parallel_end.ir file.
+            - dumpMode: Value of the parameter input_output.
+            - rankID: Logical device id.
+
     Inputs:
         - **file** (str) - The path of the file to be saved.
         - **input_x** (Tensor) - Input Tensor of any dimension.
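A minimal sketch of the new input_output argument (the file name "x_out" is arbitrary):

import numpy as np
from mindspore import Tensor, ops

dump = ops.TensorDump(input_output="all")  # dump both the producer's output and the consumer's input slices
x = Tensor(np.arange(12).reshape(3, 4).astype(np.float32))
dump("x_out", x)  # output-slice files follow the id_fileName.npy pattern, e.g. 0_x_out.npy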
@@ -280,7 +311,7 @@ class TensorDump(Primitive):
          [6. 7. 8. 9.]]
     """
     @prim_attr_register
-    def __init__(self):
+    def __init__(self, input_output='out'):
         """Initialize TensorDump."""
         if security.enable_security():
             raise ValueError('The TensorDump is not supported, please without `-s on` and recompile source.')
@@ -314,6 +345,8 @@ class HistogramSummary(Primitive):
     It must be used with SummaryRecord or SummaryCollector, which specify the directory of the summary file.
     The summary file can be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
     mindinsight/docs/en/master/index.html>`_ for details.
+    In Ascend platform with graph mode, can set environment variables `MS_DUMP_SLICE_SIZE` and `MS_DUMP_WAIT_TIME`
+    to solve operator execution failure when calling this operator intensively.

     Inputs:
         - **name** (str) - The name of the input variable.
@@ -499,16 +532,15 @@ class HookBackward(PrimitiveWithInfer):
     def __init__(self, hook_fn, cell_id=""):
         """Initialize HookBackward."""
         super(HookBackward, self).__init__(self.__class__.__name__)
-        if not
-
-            f"but got {type(hook_fn)}.")
+        if not check_hook_fn("HookBackward", hook_fn):
+            return
         if cell_id != "":
             logger.warning(f"The args 'cell_id' of HookBackward will be removed in a future version. If the value of "
                            f"'cell_id' is set, the hook function will not work.")
         self.add_prim_attr("cell_id", cell_id)
         self.init_attrs["cell_id"] = cell_id
         self.cell_id = cell_id
-        self.
+        self.set_hook_fn(hook_fn, HookType.HookBackward)

     def infer_shape(self, *inputs_shape):
         if len(inputs_shape) == 1:
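User code registers the hook the same way as before; a minimal sketch (PyNative mode assumed, where HookBackward fires during backpropagation):

import mindspore as ms
from mindspore import ops, Tensor

def hook_fn(grad):
    print("grad:", grad)  # called with the gradient flowing through the hooked value

hook = ops.HookBackward(hook_fn)

def net(x, y):
    return hook(x * y) + x

grad_fn = ms.grad(net, grad_position=(0, 1))
grads = grad_fn(Tensor(1.0, ms.float32), Tensor(2.0, ms.float32))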
mindspore/ops/operations/manually_defined/_inner.py:

@@ -23,6 +23,7 @@ from mindspore.common import dtype as mstype
 from mindspore import _checkparam as validator
 from mindspore.common._decorator import deprecated
 from mindspore.ops.primitive import prim_attr_register, Primitive
+from mindspore import log as logger


 class ScalarCast(Primitive):
@@ -59,3 +60,14 @@ class ScalarCast(Primitive):
         value = np.cast[np_dtype.lower()](input_x)
         value = value.item()
         return value
+
+
+class TensorReport(Primitive):
+    @prim_attr_register
+    def __init__(self):
+        """Initialize TensorReport"""
+        self.add_prim_attr("side_effect_io", True)
+        self.add_prim_attr("channel_name", "ms_tensor_report")
+
+    def __call__(self, file, input_x):
+        logger.warning("TensorReport doesn't support pynative mode.")
mindspore/ops/operations/manually_defined/ops_def.py:

@@ -18,6 +18,7 @@ from __future__ import division

 import numbers
 import math
+import types
 import numpy as np
 from mindspore.ops import signature as sig
 from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register, PrimitiveWithInfer
@@ -937,6 +938,10 @@

     Refer to :func:`mindspore.ops.tile` for more details.

+    Note:
+        On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
+        where more than 4 dimensions are repeated simultaneously.
+
     Inputs:
         - **input** (Tensor) - The tensor whose elements need to be repeated. Set the shape of input tensor as
           :math:`(x_1, x_2, ..., x_S)` .
@@ -1025,6 +1030,10 @@ def tile(input, dims):
     output tensor has `input.shape[i] * dims[i]` elements, and the values of `input`
     are replicated `dims[i]` times along the i'th dimension.

+    Note:
+        On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
+        where more than 4 dimensions are repeated simultaneously.
+
     Args:
         input (Tensor): The tensor whose elements need to be repeated. Set the shape of input tensor as
             :math:`(x_1, x_2, ..., x_S)` .
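A quick check of the dims semantics described above (well within the Ascend limits just noted):

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
y = ops.tile(x, (2, 3))  # dim i becomes input.shape[i] * dims[i]
print(y.shape)  # (4, 6)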
@@ -1127,16 +1136,16 @@ class Cast(Primitive):
     taken into account. As long as the real part is non-zero, it returns True; otherwise, it returns False.

     Inputs:
-        - **
+        - **input** (Union[Tensor, Number]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
           The tensor to be cast.
-        - **
+        - **dtype** (dtype.Number) - The valid data type of the output tensor. Only constant value is allowed.

     Outputs:
-        Tensor, the shape of tensor is the same as `
+        Tensor, the shape of tensor is the same as `input`, :math:`(x_1, x_2, ..., x_R)`.

     Raises:
-        TypeError: If `
-        TypeError: If `
+        TypeError: If `input` is neither Tensor nor Number.
+        TypeError: If `dtype` is not a Number.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -1146,10 +1155,10 @@ class Cast(Primitive):
         >>> import numpy as np
         >>> from mindspore import Tensor, ops
         >>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
-        >>>
-        >>>
+        >>> input = Tensor(input_np)
+        >>> dtype = mindspore.int32
         >>> cast = ops.Cast()
-        >>> output = cast(
+        >>> output = cast(input, dtype)
         >>> print(output.dtype)
         Int32
         >>> print(output.shape)
@@ -1162,17 +1171,15 @@
         self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])

     def check_elim(self, x, dtype):
-        if isinstance(x,
-
-
-            if data.dtype == dtype:
-                return (True, x)
-        if isinstance(x, Tensor) and x.dtype == dtype:
-            x = Tensor(x)
-            x.set_cast_dtype()
+        if isinstance(x, Parameter):
+            data = x.data
+            if data.dtype == dtype:
                 return (True, x)
-
-
+        if isinstance(x, Tensor) and x.dtype == dtype:
+            x.set_cast_dtype()
+            return (True, x)
+        if isinstance(x, numbers.Number):
+            return (True, Tensor(x, dtype=dtype))
         return (False, None)

     def __call__(self, input_x, dtype):
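What the rewritten check_elim means for callers, as a sketch: casting to the tensor's existing dtype is short-circuited, and Python numbers are folded into constant Tensors.

import mindspore as ms
from mindspore import Tensor, ops

cast = ops.Cast()
t = Tensor([1.0, 2.0], ms.float32)
same = cast(t, ms.float32)  # Tensor branch: dtype already matches, no cast kernel needed
num = cast(3, ms.int32)     # numbers.Number branch: folded to Tensor(3, dtype=int32)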
@@ -1187,7 +1194,7 @@ def to_sequence(val):
     to_sequence
     """
     if isinstance(val, (tuple, list)):
-        return val
+        return tuple(val)
     return (val,)

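The effect of the one-line fix, shown on a standalone copy of the helper: lists are now normalized to tuples instead of being returned as-is.

def to_sequence(val):
    if isinstance(val, (tuple, list)):
        return tuple(val)
    return (val,)

assert to_sequence([1, 2]) == (1, 2)  # previously the list came back unchanged
assert to_sequence(3) == (3,)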
@@ -1891,7 +1898,7 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
     H2 -- Hidden size of key and value, which equals to N2 * D.

     .. warning::
-        This is an experimental API that is subject to change or deletion. Only support on Atlas training series.
+        This is an experimental API that is subject to change or deletion. Only support on Atlas A2 training series.

     Args:
         query (Tensor[float16, bfloat16]): The query tensor. Input tensor of shape :math:`(B, S1, H1)`,
@@ -2014,3 +2021,249 @@
                        inner_precise, input_layout, sparse_mode)
     return rank_op(query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen,
                    actual_seq_kvlen)[3]
+
+
+class WhileLoop(Primitive):
+    """
+    Provide a useful op for reducing compilation times of while loop.
+    The execution logic of the WhileLoop operator can be roughly represented by the following code:
+
+    .. code-block:: python
+
+        def WhileLoop(cond_func, loop_func, init_val):
+            while(cond_func(init_val)):
+                init_val = loop_func(init_val)
+            return init_val
+
+    The current WhileLoop operator has the following syntactic limitations:
+
+    - Using a side-effect function as `loop_func` is currently not support,
+      such as operations that modify parameters, global variables, etc.
+    - The return value of `loop_func` being of a different type or shape
+      from the `init_val` is currently not support.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Inputs:
+        - **cond_func** (Function) - The condition function.
+        - **loop_func** (Function) - The loop function, take one argument and
+          return value has the same type with input argument.
+        - **init_val** (Union[Tensor, number, str, bool, list, tuple, dict]) - The initial value.
+
+    Outputs:
+        Union[Tensor, number, str, bool, list, tuple, dict], the final result of the while loop,
+        has same type and shape with input `init_val` .
+
+    Raises:
+        TypeError: If `cond_func` is not a function.
+        TypeError: If `loop_func` is not a function.
+        ValueError: If `loop_func` cannot take `init_val` as input or has different
+            output type or shape with `init_val` .
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> def loop_while_fun(init_val):
+        ...     val = init_val
+        ...     val = val + 1
+        ...     return val
+        ...
+        >>> init_state = 10
+        >>> while_loop = ops.WhileLoop()
+        >>> result = while_loop(lambda x : x < 100, loop_while_fun, init_state)
+        >>> print(result)
+        100
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize WhileLoop."""
+
+    def __call__(self, cond_func, loop_func, init_val):
+        validator.check_value_type("cond_func", cond_func,
+                                   [types.FunctionType, types.MethodType], "WhileLoop")
+        validator.check_value_type("loop_func", loop_func,
+                                   [types.FunctionType, types.MethodType], "WhileLoop")
+        val = init_val
+        try:
+            while cond_func(val):
+                val = loop_func(val)
+        except Exception as e:
+            raise ValueError("Invalid loop_func, please check input arguments and \
+                return value, error info: {}".format(e))
+        return val
+
+
+class Scan(Primitive):
+    """
+    Scan a function over an array while the processing of the current element
+    depends on the execution result of the previous element.
+    The execution logic of the Scan operator can be roughly represented by the following code:
+
+    .. code-block:: python
+
+        def Scan(loop_func, init, xs, length=None):
+            if xs is None:
+                xs = [None] * length
+            carry = init
+            ys = []
+            for x in xs:
+                carry, y = loop_func(carry, x)
+                ys.append(y)
+            return carry, ys
+
+    The current Scan operator has the following syntactic limitations:
+
+    - Using a side-effect function as `loop_func` is currently not support,
+      such as operations that modify parameters, global variables, etc.
+    - The first element of the return value of `loop_func` being of a different
+      type or shape from the `init_val` is currently not support.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Inputs:
+        - **loop_func** (Function) - The loop function.
+        - **init** (Union[Tensor, number, str, bool, list, tuple, dict]) - An initial loop carry value.
+        - **xs** (Union[tuple, list, None]) - The value over which to scan.
+        - **length** (Union[int, None], optional) - The size of xs. Default: ``None`` .
+        - **unroll** (bool, optional) - The flag for whether to perform loop unrolling in compile process.
+          Default: ``True`` .
+
+    Outputs:
+        Tuple(Union[Tensor, number, str, bool, list, tuple, dict], list). Output of scan loop,
+        a tuple with two elements, the first element has same type and shape with init argument,
+        and the second is a list containing the results of each loop.
+
+    Raises:
+        TypeError: If `loop_func` is not a function.
+        TypeError: If `xs` is not a tuple, a list or None.
+        TypeError: If `length` is not an int or None.
+        TypeError: If `unroll` is not a bool.
+        ValueError: If `loop_func` cannot take `init` and element of `xs` as inputs.
+        ValueError: If the return value of `loop_func` is not a tuple with two elements,
+            or the first element of the tuple has different type or shape from `init` .
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> def cumsum(res, el):
+        ...     res = res + el
+        ...     return res, res
+        ...
+        >>> a = [1, 2, 3, 4]
+        >>> result_init = 0
+        >>> scan_op = ops.Scan()
+        >>> result = scan_op(cumsum, result_init, a)
+        >>> print(result == (10, [1, 3, 6, 10]))
+        True
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize Scan."""
+
+    def __call__(self, loop_func, init, xs, length=None, unroll=True):
+        validator.check_value_type("loop_func", loop_func,
+                                   [types.FunctionType, types.MethodType], "Scan")
+        validator.check_value_type("xs", xs, [list, tuple, None], "Scan")
+        if xs is None:
+            validator.check_value_type("length", length, [int], "Scan")
+            xs = [None] * length
+        carry = init
+        length = len(xs)
+        if not length:
+            return init, []
+        try:
+            carry, y = loop_func(carry, xs[0])
+            ys = [y]
+            i = 1
+            while i < length:
+                carry, y = loop_func(carry, xs[i])
+                ys.append(y)
+                i = i + 1
+        except Exception as e:
+            raise ValueError("Invalid loop_func, please check input arguments and \
+                return value, error info: {}".format(e))
+        return carry, ys
+
+
+class ForiLoop(Primitive):
+    """
+    Provide a useful op for loop from lower to upper.
+    The execution logic of the ForiLoop operator can be roughly represented by the following code:
+
+    .. code-block:: python
+
+        def ForiLoop(lower, upper, loop_func, init_val):
+            for i in range(lower, upper):
+                init_val = loop_func(i, init_val)
+            return init_val
+
+    The current ForiLoop operator has the following syntactic limitations:
+
+    - Using a side-effect function as `loop_func` is currently not support,
+      such as operations that modify parameters, global variables, etc.
+    - The return value of `loop_func` being of a different type or shape
+      from the `init_val` is currently not support.
+    - Negative numbers or custom increments is currently not support.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Inputs:
+        - **lower** (Union[int, Tensor]) - The start index of loop.
+        - **upper** (Union[int, Tensor]) - The end index of loop.
+        - **loop_func** (Function) - The loop function, takes two arguments.
+        - **init_val** (Union[Tensor, number, str, bool, list, tuple, dict]) - The init value.
+        - **unroll** (bool, optional) - The flag for whether unroll in compile process,
+          only valid when the number of loop iterations is determined. Default: ``True`` .
+
+    Outputs:
+        Union[Tensor, number, str, bool, list, tuple, dict], the final result of the loop,
+        has same type and shape with input `init_val` .
+
+    Raises:
+        TypeError: If `lower` is not an int or a Tensor.
+        TypeError: If `upper` is not an int or a Tensor.
+        TypeError: If `loop_func` is not a function.
+        ValueError: If `loop_func` cannot take index and `init_val` as arguments or if the type
+            of output it produces is different from the type or shape of `init_val` .
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> def cumsum(index, res):
+        ...     return index + res
+        ...
+        >>> result_init = 0
+        >>> fori_loop = ops.ForiLoop()
+        >>> result = fori_loop(0, 4, cumsum, result_init)
+        >>> print(result)
+        6
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize ForiLoop."""
+
+    def __call__(self, lower, upper, loop_func, init_val, unroll=True):
+        validator.check_value_type("lower", lower, [int, Tensor], "ForiLoop")
+        validator.check_value_type("upper", upper, [int, Tensor], "ForiLoop")
+        validator.check_value_type("loop_func", loop_func,
+                                   [types.FunctionType, types.MethodType], "ForiLoop")
+        val = init_val
+        try:
+            for i in range(lower, upper):
+                val = loop_func(i, val)
+        except Exception as e:
+            raise ValueError("Invalid loop_func, please check input arguments and \
+                return value, error info: {}".format(e))
+        return val