mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/ops/function/vmap_func.py CHANGED

@@ -27,8 +27,9 @@ def vmap(fn, in_axes=0, out_axes=0):
     Vmap is pioneered by Jax and it removes the restriction of batch dimension on the operator, and provides a
     more convenient and unified operator expression. Moreover, it allows users to composite with other functional
     modules such as :func:`mindspore.grad`, to improve the development efficiency, please refer to the
-    `Automatic Vectorization (Vmap) <https://www.mindspore.cn/
-    for more detail.
+    `Automatic Vectorization (Vmap) <https://www.mindspore.cn/docs/en/master/model_train/train_process/optimize/vmap.html>`_
+    tutorial for more detail.
+    In addition, the vectorizing map does not execute loops outside the function, but sinks loops
     into the primitive operations of the function for better performance. When combined with `Graph Kernel Fusion`,
     operational efficiency would be further improved.
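The docstring change above describes composing vmap with :func:`mindspore.grad`; a minimal sketch of that pattern (function, shapes, and values are illustrative, not taken from the diff):

```python
import mindspore as ms
import mindspore.numpy as mnp

def scalar_fn(x, w):
    # Per-sample scalar loss: dot product of one sample with the shared weights.
    return (x * w).sum()

# Differentiate w.r.t. x, then vectorize over the leading batch axis of x
# while broadcasting w to every sample.
batched_grad = ms.vmap(ms.grad(scalar_fn, grad_position=0), in_axes=(0, None))

x = mnp.ones((4, 3))
w = mnp.arange(3).astype(ms.float32)
print(batched_grad(x, w).shape)  # (4, 3): one gradient row per sample
```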
mindspore/ops/functional.py CHANGED
@@ -1,6 +1,6 @@
 # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 #
-# Copyright 2021-
+# Copyright 2021-2024 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@
 from mindspore.common._register_for_tensor import tensor_operator_registry
 from mindspore.ops import _constants
 from mindspore.ops.function import *
-from mindspore.ops.function.array_func import
+from mindspore.ops.function.array_func import chunk_ext, zero_
 from mindspore.ops.function.math_func import all, argmax_ext
 from mindspore.ops.function.random_func import uniform_ext
 from mindspore.ops import operations as P
@@ -34,14 +34,15 @@ from mindspore.ops.operations.math_ops import Roll
 from mindspore.ops.composite.math_ops import mm
 from mindspore.ops.function.math_func import dot
 from mindspore.ops import auto_generate
+from mindspore.ops.auto_generate import cast
 from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
-from mindspore.ops.operations.manually_defined.ops_def import scalar_div, scalar_mod, scalar_add, scalar_mul
-    scalar_sub, scalar_gt, scalar_ge, scalar_le, scalar_lt, scalar_eq, scalar_floordiv, scalar_log, scalar_pow
+from mindspore.ops.operations.manually_defined.ops_def import scalar_div, scalar_mod, scalar_add, scalar_mul, \
+    scalar_sub, scalar_gt, scalar_ge, scalar_le, scalar_lt, scalar_eq, scalar_floordiv, scalar_log, scalar_pow, \
     scalar_uadd, scalar_usub, flash_attention_score

 typeof = Primitive('typeof')
 hastype = Primitive('hastype')
-
+_cast = P.Cast()
 dtype = P.DType()
 isconstant = _inner_ops.IsConstant()
 isconstant.set_const_prim(True)
@@ -116,7 +117,8 @@ reduced_shape = Primitive("reduced_shape")
 # shape_mul:input must be shape multiply elements in tuple(shape)
 shape_mul = _sequence_ops.shape_mul()

-setattr(tensor_operator_registry, 'tuple_to_tensor',
+setattr(tensor_operator_registry, 'tuple_to_tensor',
+        _sequence_ops.TupleToTensor)
 setattr(tensor_operator_registry, 'add', add)
 setattr(tensor_operator_registry, 'softmax', softmax)
 setattr(tensor_operator_registry, 'addr', addr)
@@ -136,6 +138,7 @@ setattr(tensor_operator_registry, 'rsqrt', rsqrt)
 setattr(tensor_operator_registry, 'bincount', bincount)
 setattr(tensor_operator_registry, 'slogdet', slogdet)
 setattr(tensor_operator_registry, 'trace', trace)
+setattr(tensor_operator_registry, 'tracev2', auto_generate.trace_v2_op)
 setattr(tensor_operator_registry, 'tril', tril)
 setattr(tensor_operator_registry, 'chunk', chunk)
 setattr(tensor_operator_registry, 'count_nonzero', count_nonzero)
@@ -210,7 +213,8 @@ setattr(tensor_operator_registry, 'dot', dot)
 setattr(tensor_operator_registry, 'outer', outer)
 setattr(tensor_operator_registry, 'log1p', log1p)
 setattr(tensor_operator_registry, 'logdet', logdet)
-setattr(tensor_operator_registry,
+setattr(tensor_operator_registry,
+        'log_matrix_determinant', log_matrix_determinant)
 setattr(tensor_operator_registry, 'matrix_determinant', matrix_determinant)
 setattr(tensor_operator_registry, 'ceil', ceil)
 setattr(tensor_operator_registry, 'fillv2', P.FillV2)
@@ -223,6 +227,7 @@ setattr(tensor_operator_registry, 'vsplit', vsplit)
 setattr(tensor_operator_registry, 'hsplit', hsplit)
 setattr(tensor_operator_registry, 'dsplit', dsplit)
 setattr(tensor_operator_registry, 'zeros_like', zeros_like)
+setattr(tensor_operator_registry, 'zero_', zero_)
 setattr(tensor_operator_registry, 'scalar_to_tensor', scalar_to_tensor)
 setattr(tensor_operator_registry, 'stop_gradient', stop_gradient)
 setattr(tensor_operator_registry, 'masked_fill', masked_fill)
@@ -264,6 +269,7 @@ setattr(tensor_operator_registry, 'tanh', tanh)
 setattr(tensor_operator_registry, 'exp', exp)
 setattr(tensor_operator_registry, 'addbmm', addbmm)
 setattr(tensor_operator_registry, 'addmm', addmm)
+setattr(tensor_operator_registry, 'addmm_', auto_generate.inplace_addmm_op)
 setattr(tensor_operator_registry, 'addmv', addmv)
 setattr(tensor_operator_registry, 'adjoint', adjoint)
 setattr(tensor_operator_registry, 'asinh', asinh)
@@ -314,7 +320,7 @@ setattr(tensor_operator_registry, 'unsqueeze', unsqueeze)
 setattr(tensor_operator_registry, 'expand_dims', expand_dims)
 setattr(tensor_operator_registry, 'contiguous', auto_generate.contiguous)
 # support GE backend for no compare operators
-setattr(tensor_operator_registry, 'cast',
+setattr(tensor_operator_registry, 'cast', _cast)
 setattr(tensor_operator_registry, 'shape_mul', shape_mul)
 setattr(tensor_operator_registry, 'concatenate', concat)
 setattr(tensor_operator_registry, 'fill', fill)
@@ -392,12 +398,13 @@ setattr(tensor_operator_registry, 'argwhere', argwhere)
 setattr(tensor_operator_registry, 'coo_add', coo_add)
 setattr(tensor_operator_registry, 'topk', topk)
 setattr(tensor_operator_registry, 'isfinite', isfinite)
-setattr(tensor_operator_registry, 'to',
-setattr(tensor_operator_registry, 'bool',
-setattr(tensor_operator_registry, 'float',
-setattr(tensor_operator_registry, 'half',
-setattr(tensor_operator_registry, 'int',
-setattr(tensor_operator_registry, 'long',
+setattr(tensor_operator_registry, 'to', _cast)
+setattr(tensor_operator_registry, 'bool', _cast)
+setattr(tensor_operator_registry, 'float', _cast)
+setattr(tensor_operator_registry, 'half', _cast)
+setattr(tensor_operator_registry, 'int', _cast)
+setattr(tensor_operator_registry, 'long', _cast)
+setattr(tensor_operator_registry, 'byte', _cast)
 setattr(tensor_operator_registry, 'cholesky', cholesky)
 setattr(tensor_operator_registry, 'cholesky_inverse', cholesky_inverse)
 setattr(tensor_operator_registry, 'cholesky_solve', cholesky_solve)
@@ -440,6 +447,9 @@ setattr(tensor_operator_registry, 'imag', imag)
 setattr(tensor_operator_registry, 'repeat_interleave', repeat_interleave)
 setattr(tensor_operator_registry, 'rad2deg', rad2deg)
 setattr(tensor_operator_registry, 'deg2rad', deg2rad)
+setattr(tensor_operator_registry, 'copy_', auto_generate.copy_ext)
+setattr(tensor_operator_registry, 'add_', auto_generate.inplace_add_ext)
+setattr(tensor_operator_registry, 'adds_', auto_generate.inplace_adds_ext)
 setattr(tensor_operator_registry, 'copysign', copysign)
 setattr(tensor_operator_registry, 'roll', Roll)
 setattr(tensor_operator_registry, 'rot90', rot90)
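The registry entries above back Tensor conversion methods such as Tensor.to(), Tensor.float() and Tensor.int(), all of which now route to a single Cast primitive. A minimal sketch of the equivalent public-API behaviour (values are illustrative):

```python
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor([[1.5, 2.5]], ms.float32)

y = x.int()                  # dtype-conversion method, dispatched through the registry
z = ops.cast(x, ms.float16)  # explicit functional cast
print(y.dtype, z.dtype)      # Int32 Float16
```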
mindspore/ops/operations/custom_ops.py CHANGED

@@ -82,7 +82,7 @@ class _CustomInstaller:
         for dir_name in dir_names:
             if not os.path.isdir(dir_name):
                 try:
-                    os.makedirs(dir_name, exist_ok=True)
+                    os.makedirs(dir_name, mode=0o700, exist_ok=True)
                 except OSError as err:
                     if err.errno == 17:  # File exists
                         pass
@@ -121,7 +121,7 @@ class _CustomInstaller:

     def _find_ai_cpu_so_path(self, so_file):
         """find the absolute path of so"""
-        current_path = os.path.dirname(os.path.
+        current_path = os.path.dirname(os.path.realpath(__file__))
         search_paths = [current_path + "/../lib", current_path + "/../lib/plugin/ascend"]
         for path in search_paths:
             so_path = os.path.join(path, so_file)
@@ -235,7 +235,7 @@ class _CustomInstaller:
         # generate and copy reg info file
         op_info = self._gen_ai_core_reg_info(imply_path, self.func.__name__)
         self._copy_file(imply_path, self.ai_core_impl_dir)
-        for arc_name in ["ascend910", "ascend910b", "
+        for arc_name in ["ascend910", "ascend910b", "ascend910_93", "ascend310p"]:
             arc_dir = os.path.join(self.ai_core_config_dir, arc_name)
             _CustomInstaller._create_dir(arc_dir)
             self._save_op_info(arc_dir, "aic-{}-ops-info.json".format(arc_name), op_info)
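The makedirs change above tightens permissions on directories created by the custom-op installer. A small, generic sketch of the same pattern (the path is hypothetical, not from the package):

```python
import os
import stat

# Owner-only directory; exist_ok=True makes a pre-existing directory a no-op,
# so the errno 17 ("File exists") fallback is only needed for other OSErrors.
cache_dir = os.path.expanduser("~/.my_tool/cache")  # hypothetical path
os.makedirs(cache_dir, mode=0o700, exist_ok=True)
print(stat.filemode(os.stat(cache_dir).st_mode))  # typically 'drwx------'
```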
mindspore/ops/operations/__init__.py CHANGED

@@ -55,7 +55,7 @@ from .comm_ops import (AllGather, AllReduce, Reduce, NeighborExchange, NeighborE
                        _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
                        _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
                        _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather,
-                       _VirtualPipelineEnd, AlltoAllV, ReduceScatter)
+                       _VirtualPipelineEnd, AlltoAllV, ReduceScatter, _VirtualAssignKvCache)
 from .control_ops import GeSwitch, Merge
 from .custom_ops import (Custom)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
@@ -96,7 +96,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
                      InstanceNorm,
                      GeLU, FastGeLU, Elu, CeLU,
                      GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCLossV2Grad, CTCGreedyDecoder,
-                     LogSoftmax, MaxPool3D, AvgPool3D,
+                     LogSoftmax, LogSoftmaxExt, MaxPool3D, AvgPool3D,
                      MaxPool, DataFormatDimMap,
                      AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
                      MaxPoolWithArgmaxV2, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6,
@@ -136,6 +136,7 @@ from ..deprecated import (identity, DropoutDoMask, MaxPoolWithArgmax, DropoutGen
                           TensorAdd, InplaceUpdate, ScatterNonAliasingAdd,
                           BatchToSpaceND, Unpack, GatherV2, DynamicShape, ScalarToArray, Pack)
 from .manually_defined._inner import ScalarCast
+from .manually_defined import WhileLoop, Scan, ForiLoop
 from .reshard_ops import (Reshard)

 __all__ = [
@@ -203,6 +204,7 @@ __all__ = [
     'Softmax',
     'Softsign',
     'LogSoftmax',
+    'LogSoftmaxExt',
     'SoftmaxCrossEntropyWithLogits',
     'BCEWithLogitsLoss',
     'ROIAlign',
@@ -337,6 +339,9 @@ __all__ = [
     'TupleToArray',
     'GeSwitch',
     'Merge',
+    'WhileLoop',
+    'Scan',
+    'ForiLoop',
     'CheckValid',
     'BartlettWindow',
     'BlackmanWindow',
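WhileLoop, Scan and ForiLoop are newly exported control-flow primitives. A sketch of ForiLoop usage, assuming it follows a JAX-style (lower, upper, body, init) calling convention; treat the exact signature as an assumption, not a confirmed API:

```python
import mindspore as ms
from mindspore import Tensor, ops

def body(i, carry):
    # Fold the loop index into the carried value.
    return carry + i

# Assumed calling convention: instantiate once, then call with
# (lower, upper, body_fn, init_value) and receive the final carry.
fori_loop = ops.ForiLoop()
result = fori_loop(0, 5, body, Tensor(0, ms.int32))
print(result)  # 0 + 1 + 2 + 3 + 4 = 10
```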
mindspore/ops/operations/_grad_ops.py CHANGED

@@ -35,8 +35,8 @@ from ..auto_generate import (AbsGrad, ACosGrad, LogitGrad, AcoshGrad, AsinGrad,
                              SigmoidGrad, HSwishGrad, NLLLossGrad, AtanGrad, GridSampler3DGrad, GridSampler2DGrad,
                              ResizeBicubicGrad, HSigmoidGrad, CholeskyGrad, ResizeNearestNeighborGrad, LayerNormGrad,
                              HShrinkGrad, LayerNormGradGrad, SiLUGrad, MaximumGrad, MaximumGradGrad, RmsNormGrad,
-                             FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad,
-                             BinaryCrossEntropyGrad)
+                             FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad, MaskedSelectGrad,
+                             BinaryCrossEntropyGrad, SoftShrinkGrad, SeluGrad)


 class SparseFillEmptyRowsGrad(Primitive):
@@ -1658,35 +1658,6 @@ class SoftMarginLossGrad(Primitive):
         self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)


-class StridedSliceV2Grad(Primitive):
-    """
-    Performs grad of StridedSliceV2 operation.
-
-    Inputs:
-        - **shapex** (Tensor) - StridedSliceV2 shape of input
-        - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
-          constant value is allowed.
-        - **end** (tuple[int]) - A tuple or which represents the maximum location where to end.
-          Only constant value is allowed.
-        - **strides** (tuple[int]) - A tuple which represents the stride is continuously added
-          before reaching the maximum location. Only constant value is allowed.
-        - **dy** (Tensor) - The output of StridedSliceV2
-
-    Outputs:
-        Tensor, the shape same as the input of StridedSliceV2
-    """
-
-    @prim_attr_register
-    def __init__(self,
-                 begin_mask=0,
-                 end_mask=0,
-                 ellipsis_mask=0,
-                 new_axis_mask=0,
-                 shrink_axis_mask=0):
-        """Initialize StridedSliceV2Grad"""
-        self.init_prim_io_names(inputs=['shapex', 'begin', 'end', 'strides', 'dy'], outputs=['output'])
-
-
 class StridedSliceGrad(Primitive):
     """
     Performs grad of StridedSlice operation.
@@ -1991,51 +1962,6 @@ class MvlgammaGrad(Primitive):
         self.p = validator.check_value_type('p', p, [int], self.name)


-class MaskedSelectGrad(PrimitiveWithInfer):
-    """Computes gradient for MaskedSelect."""
-
-    @prim_attr_register
-    def __init__(self):
-        pass
-
-    def infer_shape(self, x, mask, grad):
-        return x
-
-    def infer_dtype(self, x, mask, grad):
-        return x
-
-
-class SoftShrinkGrad(Primitive):
-    r"""
-    Gradients for SoftShrink operation.
-
-    Args:
-        lambd - The \lambda (must be no less than zero) value for the Softshrink formulation. Default: 0.5.
-
-    Inputs:
-        - **input_grad** (Tensor) - The input gradient.
-        - **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32.
-          Any number of additional dimensions.
-
-    Outputs:
-        output - Tensor, has the same shape and data type as input_x.
-
-    Raises:
-        TypeError: If lambd is not a float.
-        TypeError: If dtype of input_x is neither float16 nor float32.
-        ValueError: If lambd is less than to 0.
-
-    Supported Platforms:
-        ``Ascend``
-    """
-
-    @prim_attr_register
-    def __init__(self, lambd=0.5):
-        self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output'])
-        validator.check_value_type("lambd", lambd, [float], self.name)
-        validator.check_number("lambd", lambd, 0, validator.GE, self.name)
-
-
 class CdistGrad(Primitive):
     """Computes gradient for Cdist."""

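The handwritten MaskedSelectGrad and SoftShrinkGrad primitives are dropped here in favour of the generated ones imported at the top of the hunk; the user-facing op is unchanged. A short forward/backward sketch using the public API (values illustrative):

```python
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor([[1.0, 2.0], [3.0, 4.0]], ms.float32)
mask = Tensor([[True, False], [True, False]])

def select_sum(inp):
    # masked_select flattens the selected elements into a 1-D tensor.
    return ops.masked_select(inp, mask).sum()

print(select_sum(x))           # 4.0
print(ms.grad(select_sum)(x))  # ones where mask is True, zeros elsewhere
```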
mindspore/ops/operations/_infer_ops.py CHANGED

@@ -16,4 +16,4 @@
 """Operator of infer net"""
 # pylint: disable=unused-import
 from ..auto_generate import (QuantV2, DynamicQuantExt, QuantBatchMatmul, WeightQuantBatchMatmul, KVCacheScatterUpdate,
-                             FusedInferAttentionScore, GroupedMatmul, MoeFinalizeRouting)
+                             FusedInferAttentionScore, GroupedMatmul, MoeFinalizeRouting, QuantLinearSparse)
mindspore/ops/operations/_inner_ops.py CHANGED

@@ -17,6 +17,7 @@
 from types import FunctionType, MethodType
 from collections.abc import Iterable
 import os
+import weakref
 import numpy as np

 from mindspore.common import Tensor
@@ -29,7 +30,7 @@ from mindspore.ops.operations.math_ops import _infer_shape_reduce
 from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
     _run_op, _check_contains_variable
 from mindspore._c_expression import Tensor as Tensor_
-from mindspore._c_expression import typing
+from mindspore._c_expression import typing, HookType
 from mindspore import _checkparam as validator
 from mindspore.common import dtype as mstype
 from mindspore.common.parameter import Parameter
@@ -1535,7 +1536,7 @@ class CellBackwardHook(PrimitiveWithInfer):
     ...         print(grad)
     ...
     >>> hook = inner.CellBackwardHook()
-    >>> hook_fn_key = hook.register_backward_hook(
+    >>> hook_fn_key = hook.register_backward_hook()
     >>> def hook_test(x, y):
     ...     z = x * y
     ...     z = hook(z)
@@ -1556,16 +1557,19 @@ class CellBackwardHook(PrimitiveWithInfer):
     (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
     """

-    def __init__(self, cell_id=""):
+    def __init__(self, cell_id="", cell=None, hook_dict=None):
         """Initialize CellBackwardHook"""
         super(CellBackwardHook, self).__init__(self.__class__.__name__)
         self.cell_id = cell_id
+        self.cell = cell
+        self.hook_dict = weakref.ref(hook_dict)
         self.add_prim_attr("cell_id", cell_id)
-        self.
+        self.grad_output = None

-    def __call__(self, args):
-
-
+    def __call__(self, *args):
+        # If args is empty, just return.
+        if not args:
+            return args
         return _run_op(self, self.name, args)

     def infer_shape(self, *inputs_shape):
@@ -1578,51 +1582,76 @@ class CellBackwardHook(PrimitiveWithInfer):
             return inputs_type[0]
         return inputs_type

-    def register_backward_hook(self
-
-
-        mode.
-
-        Note:
-            The 'hook_fn' must be defined as the following code.
-            `cell_id` is the information of registered cell. `grad_input` is the gradient passed to the cell.
-            `grad_output` is the gradient computed and passed to the next cell or primitive, which may be modified by
-            returning a new output gradient.
-            The 'hook_fn' should have the following signature:
-            hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
-            The 'hook_fn' is executed in the python environment.
+    def register_backward_hook(self):
+        """
+        Register the backward hook function.

         Args:
-
+            None

         Returns:
-
-
-
+            None
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
         """
-        if not isinstance(hook_fn, (FunctionType, MethodType)):
-            raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' must be python "
-                            f"function, but got {type(hook_fn)}.")
-        key = self.add_backward_hook_fn(hook_fn)
-        return key

-
-
-
-
-
-
-
-
+        def hook_backward_grad(grad):
+            if self.grad_output is None:
+                self.grad_output = grad
+                # Indicates the first time of call backward hook, and need to wait for the second time call
+                return self.cell_id
+            backward_hook_grad_input = grad
+            if self.hook_dict():
+                backward_hooks = self.hook_dict().values()
+                for hook in backward_hooks:
+                    res = hook(self.cell, backward_hook_grad_input, self.grad_output)
+                    if res is None:
+                        continue
+                    if not isinstance(res, tuple):
+                        res = (res,)
+                    if len(res) != len(grad):
+                        raise TypeError(
+                            "The backward hook return value size is {} not equal to expect grad input size {}".format(
+                                len(res), len(grad)))
+                    backward_hook_grad_input = res
+            self.grad_output = None
+            return backward_hook_grad_input
+
+        self.set_hook_fn(hook_backward_grad, HookType.BackwardHook)
+
+    def register_backward_pre_hook(self):
+        """
+        Register the backward pre hook function.

         Args:
-
+            None

         Returns:
-            None
+            None
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
         """
-
+
+        def hook_backward_pre_grad(grad):
+            backward_pre_hook_grad = grad
+            if self.hook_dict():
+                backward_pre_hooks = self.hook_dict().values()
+                for hook in backward_pre_hooks:
+                    res = hook(self.cell, backward_pre_hook_grad)
+                    if res is None:
+                        continue
+                    if not isinstance(res, tuple):
+                        res = (res,)
+                    if len(res) != len(grad):
+                        raise TypeError(
+                            "The backward pre hook return value size is {} not equal to expect output size {}".format(
+                                len(res), len(grad)))
+                    backward_pre_hook_grad = res
+            return backward_pre_hook_grad
+
+        self.set_hook_fn(hook_backward_pre_grad, HookType.BackwardPreHook)


 class Format(PrimitiveWithInfer):
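These CellBackwardHook changes move hook bookkeeping into the primitive (a weak reference to the cell's hook dict plus HookType registration). At the nn.Cell level the user-facing pattern looks roughly like the sketch below; the hook signature (cell, grad_input, grad_output) is inferred from the hook(...) call order in the added code and should be treated as an assumption:

```python
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)  # backward hooks run in PyNative mode

def backward_hook(cell, grad_input, grad_output):
    # Inspect gradients flowing through the cell; return None to leave them
    # unchanged, or return new gradient(s) to override them.
    print("hook fired for", type(cell).__name__)
    return None

net = nn.Dense(3, 2)
handle = net.register_backward_hook(backward_hook)

x = Tensor([[1.0, 2.0, 3.0]], ms.float32)
_ = ms.grad(lambda inp: net(inp).sum())(x)

handle.remove()  # detach the hook when it is no longer needed
```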
@@ -2478,60 +2507,6 @@ class FFN(Primitive):
         validator.check_value_type("inner_precise", inner_precise, [int], cls_name)


-class _MirrorSilentCheck(PrimitiveWithInfer):
-    """
-    The operator _MirrorSilentCheck implements accuracy-sensitive detection on the tensor input in backpropagator.
-    Call _MirrorSilentCheck in method __call__ of derived class to implement accuracy-sensitive detection.
-
-    Inputs:
-        - **input** (Tensor) : The tensor used for detection.
-          Its data type must be mindspore.float16, mindspore.float32 or mindspore.bfloat16.
-        - **pre_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
-          Please only generated by method generate_params() of ASDBase.
-        - **min_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
-          Please only generated by method generate_params() of ASDBase.
-        - **max_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
-          Please only generated by method generate_params() of ASDBase.
-        - **cnt** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
-          Please only generated by method generate_params() of ASDBase.
-          After each invocation of _MirrorSilentCheck, increment the value of cnt by one.
-
-    Outputs:
-        - **output** (Tensor) - Same shape, type and value as `input`.
-    """
-    @prim_attr_register
-    def __init__(self, min_steps=8):
-        upper_thresh, sigma_thresh = self.get_thresh()
-        self.min_steps = min_steps
-        self.thresh_l1 = upper_thresh[0]
-        self.coeff_l1 = sigma_thresh[0]
-        self.thresh_l2 = upper_thresh[1]
-        self.coeff_l2 = sigma_thresh[1]
-        self.add_prim_attr('side_effect_mem', True)
-
-    def parse_thresh(self, env_var_name, default_value, min_value):
-        env_var = os.environ.get(env_var_name, default=default_value)
-        thresh = [value.strip() for value in env_var.split(",")]
-        if len(thresh) != 2 or not all(value.isdigit() for value in thresh):
-            thresh = default_value.split(",")
-        thresh = [float(max(int(value), min_value)) for value in thresh]
-        if thresh[0] <= thresh[1]:
-            thresh = [float(value) for value in default_value.split(",")]
-
-        return thresh
-
-    def get_thresh(self):
-        upper_thresh = self.parse_thresh("NPU_ASD_UPPER_THRESH", "1000000,10000", 3)
-        sigma_thresh = self.parse_thresh("NPU_ASD_SIGMA_THRESH", "100000,5000", 3)
-        return upper_thresh, sigma_thresh
-
-    def infer_shape(self, x_shape, pre_shape, min_shape, max_shape, n_step, loss_scale_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype, pre_dtype, min_dtype, max_dtype, n_dtype, loss_scale_dtype):
-        return x_dtype
-
-
 class _VirtualConverterEnd(PrimitiveWithInfer):
     """
     Auto parallel virtual operator.
@@ -2560,6 +2535,8 @@ class _VirtualConverterBegin(PrimitiveWithInfer):
         self.output_nums = output_nums

     def infer_shape(self, arg):
+        if self.output_nums == 0:
+            return ValueError("output_nums can\'t be zero.")
         new_arg = (arg[0] / self.output_nums,) + tuple(arg[1:])
         return (new_arg,) * self.output_nums