mindspore 2.3.0-cp39-cp39-win_amd64.whl → 2.4.0-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/ops/operations/_grad_ops.py
@@ -35,8 +35,8 @@ from ..auto_generate import (AbsGrad, ACosGrad, LogitGrad, AcoshGrad, AsinGrad,
                              SigmoidGrad, HSwishGrad, NLLLossGrad, AtanGrad, GridSampler3DGrad, GridSampler2DGrad,
                              ResizeBicubicGrad, HSigmoidGrad, CholeskyGrad, ResizeNearestNeighborGrad, LayerNormGrad,
                              HShrinkGrad, LayerNormGradGrad, SiLUGrad, MaximumGrad, MaximumGradGrad, RmsNormGrad,
-                             FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad,
-                             BinaryCrossEntropyGrad)
+                             FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad, MaskedSelectGrad,
+                             BinaryCrossEntropyGrad, SoftShrinkGrad, SeluGrad)


 class SparseFillEmptyRowsGrad(Primitive):
@@ -1658,35 +1658,6 @@ class SoftMarginLossGrad(Primitive):
         self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)


-class StridedSliceV2Grad(Primitive):
-    """
-    Performs grad of StridedSliceV2 operation.
-
-    Inputs:
-        - **shapex** (Tensor) - StridedSliceV2 shape of input
-        - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
-          constant value is allowed.
-        - **end** (tuple[int]) - A tuple which represents the maximum location where to end.
-          Only constant value is allowed.
-        - **strides** (tuple[int]) - A tuple which represents the stride that is continuously added
-          before reaching the maximum location. Only constant value is allowed.
-        - **dy** (Tensor) - The output of StridedSliceV2
-
-    Outputs:
-        Tensor, the shape same as the input of StridedSliceV2
-    """
-
-    @prim_attr_register
-    def __init__(self,
-                 begin_mask=0,
-                 end_mask=0,
-                 ellipsis_mask=0,
-                 new_axis_mask=0,
-                 shrink_axis_mask=0):
-        """Initialize StridedSliceV2Grad"""
-        self.init_prim_io_names(inputs=['shapex', 'begin', 'end', 'strides', 'dy'], outputs=['output'])
-
-
 class StridedSliceGrad(Primitive):
     """
     Performs grad of StridedSlice operation.
@@ -1991,51 +1962,6 @@ class MvlgammaGrad(Primitive):
         self.p = validator.check_value_type('p', p, [int], self.name)


-class MaskedSelectGrad(PrimitiveWithInfer):
-    """Computes gradient for MaskedSelect."""
-
-    @prim_attr_register
-    def __init__(self):
-        pass
-
-    def infer_shape(self, x, mask, grad):
-        return x
-
-    def infer_dtype(self, x, mask, grad):
-        return x
-
-
-class SoftShrinkGrad(Primitive):
-    r"""
-    Gradients for SoftShrink operation.
-
-    Args:
-        lambd (float): The :math:`\lambda` value (must be no less than zero) for the Softshrink formulation.
-            Default: 0.5.
-
-    Inputs:
-        - **input_grad** (Tensor) - The input gradient.
-        - **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32.
-          Any number of additional dimensions.
-
-    Outputs:
-        output - Tensor, has the same shape and data type as input_x.
-
-    Raises:
-        TypeError: If lambd is not a float.
-        TypeError: If dtype of input_x is neither float16 nor float32.
-        ValueError: If lambd is less than 0.
-
-    Supported Platforms:
-        ``Ascend``
-    """
-
-    @prim_attr_register
-    def __init__(self, lambd=0.5):
-        self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output'])
-        validator.check_value_type("lambd", lambd, [float], self.name)
-        validator.check_number("lambd", lambd, 0, validator.GE, self.name)
-
-
 class CdistGrad(Primitive):
     """Computes gradient for Cdist."""
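MaskedSelectGrad and SoftShrinkGrad, removed above, are relocations rather than deletions: the import hunk at the top of this file now pulls them (plus SeluGrad) from the auto-generated module instead. For reference, a minimal numpy sketch (an illustration, not code from the package) of what SoftShrinkGrad computes per the removed docstring:

import numpy as np

def softshrink_grad(input_grad, input_x, lambd=0.5):
    # SoftShrink forward is y = x - lambd (x > lambd), x + lambd (x < -lambd), else 0,
    # so the gradient passes through only where |x| > lambd.
    return np.where(np.abs(input_x) > lambd, input_grad, 0.0)

x = np.array([-1.0, -0.3, 0.2, 0.9], dtype=np.float32)
print(softshrink_grad(np.ones_like(x), x))  # [1. 0. 0. 1.]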
mindspore/ops/operations/_infer_ops.py
@@ -16,4 +16,4 @@
 """Operator of infer net"""
 # pylint: disable=unused-import
 from ..auto_generate import (QuantV2, DynamicQuantExt, QuantBatchMatmul, WeightQuantBatchMatmul, KVCacheScatterUpdate,
-                             FusedInferAttentionScore, GroupedMatmul, MoeFinalizeRouting)
+                             FusedInferAttentionScore, GroupedMatmul, MoeFinalizeRouting, QuantLinearSparse)
mindspore/ops/operations/_inner_ops.py
@@ -17,6 +17,7 @@
 from types import FunctionType, MethodType
 from collections.abc import Iterable
 import os
+import weakref
 import numpy as np

 from mindspore.common import Tensor
@@ -29,7 +30,7 @@ from mindspore.ops.operations.math_ops import _infer_shape_reduce
 from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
     _run_op, _check_contains_variable
 from mindspore._c_expression import Tensor as Tensor_
-from mindspore._c_expression import typing
+from mindspore._c_expression import typing, HookType
 from mindspore import _checkparam as validator
 from mindspore.common import dtype as mstype
 from mindspore.common.parameter import Parameter
@@ -1535,7 +1536,7 @@ class CellBackwardHook(PrimitiveWithInfer):
         ...         print(grad)
         ...
         >>> hook = inner.CellBackwardHook()
-        >>> hook_fn_key = hook.register_backward_hook(hook_fn)
+        >>> hook_fn_key = hook.register_backward_hook()
         >>> def hook_test(x, y):
         ...     z = x * y
         ...     z = hook(z)
@@ -1556,16 +1557,19 @@ class CellBackwardHook(PrimitiveWithInfer):
         (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
     """

-    def __init__(self, cell_id=""):
+    def __init__(self, cell_id="", cell=None, hook_dict=None):
         """Initialize CellBackwardHook"""
         super(CellBackwardHook, self).__init__(self.__class__.__name__)
         self.cell_id = cell_id
+        self.cell = cell
+        self.hook_dict = weakref.ref(hook_dict)
         self.add_prim_attr("cell_id", cell_id)
-        self.
+        self.grad_output = None

-    def __call__(self, args):
-
-
+    def __call__(self, *args):
+        # If args is empty, just return.
+        if not args:
+            return args
         return _run_op(self, self.name, args)

     def infer_shape(self, *inputs_shape):
@@ -1578,51 +1582,76 @@ class CellBackwardHook(PrimitiveWithInfer):
             return inputs_type[0]
         return inputs_type

-    def register_backward_hook(self, hook_fn):
-
-
-        mode.
-
-        Note:
-            The 'hook_fn' must be defined as the following code.
-            `cell_id` is the information of registered cell. `grad_input` is the gradient passed to the cell.
-            `grad_output` is the gradient computed and passed to the next cell or primitive, which may be modified by
-            returning a new output gradient.
-            The 'hook_fn' should have the following signature:
-            hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
-            The 'hook_fn' is executed in the python environment.
+    def register_backward_hook(self):
+        """
+        Register the backward hook function.

         Args:
-
+            None

         Returns:
-
+            None

-
-
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
         """
-        if not isinstance(hook_fn, (FunctionType, MethodType)):
-            raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' must be python "
-                            f"function, but got {type(hook_fn)}.")
-        key = self.add_backward_hook_fn(hook_fn)
-        return key

-
-
-
-
-
-
-
-
+        def hook_backward_grad(grad):
+            if self.grad_output is None:
+                self.grad_output = grad
+                # First call of the backward hook; wait for the second call
+                return self.cell_id
+            backward_hook_grad_input = grad
+            if self.hook_dict():
+                backward_hooks = self.hook_dict().values()
+                for hook in backward_hooks:
+                    res = hook(self.cell, backward_hook_grad_input, self.grad_output)
+                    if res is None:
+                        continue
+                    if not isinstance(res, tuple):
+                        res = (res,)
+                    if len(res) != len(grad):
+                        raise TypeError(
+                            "The backward hook return value size is {} not equal to expect grad input size {}".format(
+                                len(res), len(grad)))
+                    backward_hook_grad_input = res
+            self.grad_output = None
+            return backward_hook_grad_input
+
+        self.set_hook_fn(hook_backward_grad, HookType.BackwardHook)
+
+    def register_backward_pre_hook(self):
+        """
+        Register the backward pre hook function.

         Args:
-
+            None

         Returns:
-            None
+            None
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
         """
-
+
+        def hook_backward_pre_grad(grad):
+            backward_pre_hook_grad = grad
+            if self.hook_dict():
+                backward_pre_hooks = self.hook_dict().values()
+                for hook in backward_pre_hooks:
+                    res = hook(self.cell, backward_pre_hook_grad)
+                    if res is None:
+                        continue
+                    if not isinstance(res, tuple):
+                        res = (res,)
+                    if len(res) != len(grad):
+                        raise TypeError(
+                            "The backward pre hook return value size is {} not equal to expect output size {}".format(
+                                len(res), len(grad)))
+                    backward_pre_hook_grad = res
+            return backward_pre_hook_grad
+
+        self.set_hook_fn(hook_backward_pre_grad, HookType.BackwardPreHook)


 class Format(PrimitiveWithInfer):
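The rewritten methods above are the plumbing behind the user-facing Cell hook API: the primitive now keeps registered hooks in a dict held through a weak reference (so the hook machinery does not keep the cell alive), and each hook is invoked as hook(cell, grad_input, grad_output). A minimal usage sketch, assuming the standard mindspore.nn.Cell.register_backward_hook API in PyNative mode:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)  # backward hooks fire during PyNative backprop

def backward_hook(cell, grad_input, grad_output):
    # Returning None leaves the gradient unchanged; returning a Tensor or a
    # tuple of Tensors replaces it, subject to the size check shown in the diff.
    print("backward through", type(cell).__name__)
    return None

net = nn.Dense(3, 2)
handle = net.register_backward_hook(backward_hook)
grad_fn = ms.grad(net)
_ = grad_fn(Tensor(np.ones((1, 3)), ms.float32))
handle.remove()  # drops the entry from the weakly referenced hook dict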
@@ -2478,60 +2507,6 @@ class FFN(Primitive):
         validator.check_value_type("inner_precise", inner_precise, [int], cls_name)


-class _MirrorSilentCheck(PrimitiveWithInfer):
-    """
-    The operator _MirrorSilentCheck implements accuracy-sensitive detection on the tensor input during
-    backpropagation. Call _MirrorSilentCheck in the __call__ method of a derived class to implement
-    accuracy-sensitive detection.
-
-    Inputs:
-        - **input** (Tensor) : The tensor used for detection.
-          Its data type must be mindspore.float16, mindspore.float32 or mindspore.bfloat16.
-        - **pre_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
-          Should only be generated by method generate_params() of ASDBase.
-        - **min_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
-          Should only be generated by method generate_params() of ASDBase.
-        - **max_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
-          Should only be generated by method generate_params() of ASDBase.
-        - **cnt** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
-          Should only be generated by method generate_params() of ASDBase.
-          After each invocation of _MirrorSilentCheck, the value of cnt is incremented by one.
-
-    Outputs:
-        - **output** (Tensor) - Same shape, type and value as `input`.
-    """
-    @prim_attr_register
-    def __init__(self, min_steps=8):
-        upper_thresh, sigma_thresh = self.get_thresh()
-        self.min_steps = min_steps
-        self.thresh_l1 = upper_thresh[0]
-        self.coeff_l1 = sigma_thresh[0]
-        self.thresh_l2 = upper_thresh[1]
-        self.coeff_l2 = sigma_thresh[1]
-        self.add_prim_attr('side_effect_mem', True)
-
-    def parse_thresh(self, env_var_name, default_value, min_value):
-        env_var = os.environ.get(env_var_name, default=default_value)
-        thresh = [value.strip() for value in env_var.split(",")]
-        if len(thresh) != 2 or not all(value.isdigit() for value in thresh):
-            thresh = default_value.split(",")
-        thresh = [float(max(int(value), min_value)) for value in thresh]
-        if thresh[0] <= thresh[1]:
-            thresh = [float(value) for value in default_value.split(",")]
-
-        return thresh
-
-    def get_thresh(self):
-        upper_thresh = self.parse_thresh("NPU_ASD_UPPER_THRESH", "1000000,10000", 3)
-        sigma_thresh = self.parse_thresh("NPU_ASD_SIGMA_THRESH", "100000,5000", 3)
-        return upper_thresh, sigma_thresh
-
-    def infer_shape(self, x_shape, pre_shape, min_shape, max_shape, n_step, loss_scale_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype, pre_dtype, min_dtype, max_dtype, n_dtype, loss_scale_dtype):
-        return x_dtype
-
-
 class _VirtualConverterEnd(PrimitiveWithInfer):
     """
     Auto parallel virtual operator.
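The removed _MirrorSilentCheck read its detection thresholds from the environment. Restated standalone (the parsing logic is copied from the removed parse_thresh above), this shows how NPU_ASD_UPPER_THRESH and NPU_ASD_SIGMA_THRESH were interpreted:

import os

def parse_thresh(env_var_name, default_value, min_value):
    env_var = os.environ.get(env_var_name, default=default_value)
    thresh = [value.strip() for value in env_var.split(",")]
    if len(thresh) != 2 or not all(value.isdigit() for value in thresh):
        thresh = default_value.split(",")          # malformed input: fall back to default
    thresh = [float(max(int(value), min_value)) for value in thresh]
    if thresh[0] <= thresh[1]:                     # first entry must stay strictly larger
        thresh = [float(value) for value in default_value.split(",")]
    return thresh

os.environ["NPU_ASD_UPPER_THRESH"] = "2000000,20000"
print(parse_thresh("NPU_ASD_UPPER_THRESH", "1000000,10000", 3))  # [2000000.0, 20000.0]
os.environ["NPU_ASD_UPPER_THRESH"] = "5,bad"
print(parse_thresh("NPU_ASD_UPPER_THRESH", "1000000,10000", 3))  # [1000000.0, 10000.0]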
@@ -2560,6 +2535,8 @@ class _VirtualConverterBegin(PrimitiveWithInfer):
         self.output_nums = output_nums

     def infer_shape(self, arg):
+        if self.output_nums == 0:
+            return ValueError("output_nums can\'t be zero.")
         new_arg = (arg[0] / self.output_nums,) + tuple(arg[1:])
         return (new_arg,) * self.output_nums

mindspore/ops/operations/array_ops.py
@@ -39,11 +39,12 @@ from ..auto_generate import (ExpandDims, Reshape, TensorShape, Transpose, Gather
                              OnesLike, ZerosLike, Argmax, ArgMaxExt,
                              ReverseV2, Diag, Eye, ScatterNd, ResizeNearestNeighborV2,
                              GatherNd, GatherD, Range, MaskedFill, RightShift, NonZero,
-                             ResizeNearestNeighbor, Identity, Split, CumSum, CumProd,
+                             ResizeNearestNeighbor, Identity, Split, CumSum, CumProd, MaskedSelect,
                              Cummax, Cummin, Argmin, Concat, UnsortedSegmentSum, ScalarToTensor,
                              Triu, BroadcastTo, StridedSlice, Select, TopkExt, SearchSorted)
 from .manually_defined import Rank, Shape, Tile, Cast, Ones, Zeros
 from ..auto_generate import ArgMaxWithValue, ArgMinWithValue
+from ..auto_generate import TensorScatterElements as TensorScatterElementsExt

 class _ScatterOp(PrimitiveWithInfer):
     """
@@ -769,41 +770,13 @@ class Padding(Primitive):

 class UniqueWithPad(Primitive):
     """
-
-
-    The basic function is the same as the Unique operator, but the UniqueWithPad operator adds a Pad function.
-    The returned tuple(`y`, `idx`) after the input Tensor `x` is processed by the unique operator,
-    in which the shapes of `y` and `idx` are mostly not equal. Therefore, in order to solve the above situation,
-    the UniqueWithPad operator will fill the `y` Tensor with the `pad_num` specified by the user
-    to make it have the same shape as the Tensor `idx`.
-
-    Refer to :func:`mindspore.ops.unique_with_pad` for more details.
-
-    Inputs:
-        - **x** (Tensor) - The tensor to be deduplicated. Must be a 1-D vector with type int32 or int64.
-        - **pad_num** (int) - Pad num. The data type is an int.
-
-    Outputs:
-        tuple(Tensor), tuple of 2 tensors, `y` and `idx`.
-
-        - y (Tensor) - The unique elements filled with pad_num, the shape and data type same as `x`.
-        - idx (Tensor) - The index of each value of `x` in the unique output `y`, the shape and data type same as `x`.
+    'ops.UniqueWithPad' is deprecated from version 2.4 and will be removed in a future version.

     Supported Platforms:
-
-
-    Examples:
-        >>> import mindspore
-        >>> import numpy as np
-        >>> from mindspore import Tensor, ops
-        >>> x = Tensor(np.array([1, 1, 2, 2, 3, 3, 4, 5]), mindspore.int32)
-        >>> pad_num = 8
-        >>> output = ops.UniqueWithPad()(x, pad_num)
-        >>> print(output)
-        (Tensor(shape=[8], dtype=Int32, value= [1, 2, 3, 4, 5, 8, 8, 8]),
-        Tensor(shape=[8], dtype=Int32, value= [0, 0, 1, 1, 2, 2, 3, 4]))
+        Deprecated
     """

+    @deprecated("2.4", "ops.Unique and ops.PadV3", False)
     @prim_attr_register
     def __init__(self):
         """init UniqueWithPad"""
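The @deprecated decorator points migrating code at ops.Unique plus ops.PadV3. A sketch of one possible replacement for the removed docstring example (an assumption on my part: PadV3's constant mode taking a length-2 paddings tensor for the last dimension plus a scalar constant tensor):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.array([1, 1, 2, 2, 3, 3, 4, 5]), ms.int32)
pad_num = 8

y, idx = ops.Unique()(x)              # y: [1 2 3 4 5], idx: [0 0 1 1 2 2 3 4]
pad_len = x.shape[0] - y.shape[0]     # pad y back to the input length
y_padded = ops.PadV3()(y, Tensor([0, pad_len], ms.int32), Tensor(pad_num, ms.int32))
print(y_padded)                       # [1 2 3 4 5 8 8 8]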
@@ -819,7 +792,7 @@ class Size(Primitive):

     Inputs:
         - **input_x** (Tensor) - Input parameters, the shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is
-          `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore
+          `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.

     Outputs:
         int. A scalar representing the elements' size of `input_x`, tensor is the number of elements
@@ -2112,60 +2085,6 @@ class Rint(Primitive):
         self.init_prim_io_names(inputs=['x'], outputs=['output'])


-class StridedSliceV2(Primitive):
-    r"""
-    StridedSliceV2 will be deprecated by StridedSlice in the future.
-    Extracts a strided slice of a tensor.
-    Refer to class StridedSlice for more details.
-
-    Args:
-        begin_mask (int): Starting index of the slice. Default: ``0`` .
-        end_mask (int): Ending index of the slice. Default: ``0`` .
-        ellipsis_mask (int): An int mask. Default: ``0`` .
-        new_axis_mask (int): An int mask. Default: ``0`` .
-        shrink_axis_mask (int): An int mask. Default: ``0`` .
-
-    Inputs:
-        - **input_x** (Tensor) - The input Tensor.
-        - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
-          constant value is allowed.
-        - **end** (tuple[int]) - A tuple which represents the maximum location where to end.
-          Only constant value is allowed.
-        - **strides** (tuple[int]) - A tuple which represents the stride that is continuously added
-          before reaching the maximum location. Only constant value is allowed.
-
-    Outputs:
-        Tensor, the output is explained by the following example.
-
-    Raises:
-        TypeError: If `begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask` or `shrink_axis_mask` is not an int.
-        TypeError: If `begin`, `end` or `strides` is not a tuple.
-        ValueError: If `begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask` or `shrink_axis_mask` is less than 0.
-
-    Supported Platforms:
-        ``Ascend`` ``CPU``
-
-    Examples:
-        >>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
-        ...                   [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
-        >>> strided_slice_v2 = ops.StridedSliceV2()
-        >>> output = strided_slice_v2(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
-        >>> print(output)
-        [[[3.]]
-         [[5.]]]
-    """
-
-    @prim_attr_register
-    def __init__(self,
-                 begin_mask=0,
-                 end_mask=0,
-                 ellipsis_mask=0,
-                 new_axis_mask=0,
-                 shrink_axis_mask=0):
-        """Initialize StridedSliceV2"""
-        self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
-
-
 class DiagPart(PrimitiveWithCheck):
     r"""

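The removed docstring itself says StridedSlice supersedes StridedSliceV2, so its example carries over with only the operator name changed (a sketch, relying on the identical slicing semantics the docstring claims):

import mindspore as ms
from mindspore import Tensor, ops

input_x = Tensor([[[1, 1, 1], [2, 2, 2]],
                  [[3, 3, 3], [4, 4, 4]],
                  [[5, 5, 5], [6, 6, 6]]], ms.float32)
# slice [1:3, 0:1, 2:3] with unit strides
output = ops.StridedSlice()(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
print(output)
# [[[3.]]
#  [[5.]]]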
@@ -4356,53 +4275,6 @@ class MaskedScatter(Primitive):
         self.init_prim_io_names(inputs=['x', 'mask', 'updates'], outputs=['y'])


-class MaskedSelect(PrimitiveWithCheck):
-    """
-    Returns a new 1-D Tensor which indexes the `x` tensor according to the boolean `mask`.
-    The shapes of the `mask` tensor and the `x` tensor don't need to match, but they must be broadcastable.
-
-    Inputs:
-        - **x** (Tensor) - Input Tensor of any dimension.
-        - **mask** (Tensor[bool]) - Boolean mask Tensor, has the same shape as `x`.
-
-    Outputs:
-        A 1-D Tensor, with the same type as x.
-
-    Raises:
-        TypeError: If `x` or `mask` is not a Tensor.
-        TypeError: If dtype of `mask` is not bool.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import numpy as np
-        >>> from mindspore import Tensor, ops
-        >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int32)
-        >>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
-        >>> output = ops.MaskedSelect()(x, mask)
-        >>> print(output)
-        [1 3]
-        >>> x = Tensor(2.1, mindspore.float32)
-        >>> mask = Tensor(True, mindspore.bool_)
-        >>> output = ops.MaskedSelect()(x, mask)
-        >>> print(output)
-        [2.1]
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        self.init_prim_io_names(inputs=['x', 'mask'], outputs=['output'])
-
-    def check_shape(self, x_shape, mask_shape):
-        get_broadcast_shape(x_shape, mask_shape, self.name, arg_name1="x", arg_name2="mask")
-
-    def check_dtype(self, x_dtype, mask_dtype):
-        validator.check_tensor_dtype_valid('mask', mask_dtype, [mstype.bool_], self.name)
-        validator.check_tensor_dtype_valid('x', x_dtype, (mstype.bool_,) + mstype.number_type, self.name)
-
-
 class _TensorScatterOp(PrimitiveWithInfer):
     """
     Defines TensorScatter Base Operators
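MaskedSelect itself is not gone: the import hunk near the top of this file re-exports it from auto_generate, so both the primitive and the functional form keep working. A quick sketch:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.array([1, 2, 3, 4]), ms.int32)
mask = Tensor(np.array([True, False, True, False]))
print(ops.masked_select(x, mask))   # [1 3]
print(ops.MaskedSelect()(x, mask))  # same result via the primitive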
@@ -4962,7 +4834,7 @@ class SplitV(Primitive):
         self.init_prim_io_names(inputs=['input_x'], outputs=['output'])


-class TensorScatterElements(Primitive):
+class TensorScatterElements(TensorScatterElementsExt):
     """
     Write all elements in `updates` to the index specified by `indices` in `input_x` according to the reduction
     operation specified by `reduction`.
@@ -4977,6 +4849,9 @@ class TensorScatterElements(Primitive):
     .. warning::
         This is an experimental API that is subject to change or deletion.

+    Note:
+        The backward is supported only for the case `updates.shape == indices.shape`.
+
     Args:
         axis (int, optional): Specify which axis to do scatter operation. Default: ``0`` .
         reduction (str, optional): Which reduction operation to scatter, default is ``"none"`` . Other option: "add".
@@ -4986,7 +4861,7 @@ class TensorScatterElements(Primitive):
         - **indices** (Tensor) - The index of `input_x` to do scatter operation whose data type must be int32 or
           int64. It has the same rank as `data`. And accepted range is [-s, s) where s is the size along axis.
         - **updates** (Tensor) - The tensor doing the scatter operation with `data`,
-          it has the same type as `data
+          it has the same type as `data`.

     Outputs:
         Tensor, has the same shape and type as `data`.
@@ -5021,16 +4896,7 @@ class TensorScatterElements(Primitive):

     @prim_attr_register
     def __init__(self, axis=0, reduction="none"):
-
-        validator.check_value_type("axis", axis, [int], self.name)
-        validator.check_value_type("reduction", reduction, [str], self.name)
-        validator.check_string(reduction, ["none", "add"], "reduction", self.name)
-        self.init_prim_io_names(inputs=['data', 'indices', 'updates'], outputs=['y'])
-        target = context.get_context("device_target")
-        if reduction != 'none' and target.lower() == "ascend":
-            raise ValueError(f"For '{self.name}', "
-                             f"Currently Ascend device_target only support `reduction`='none', "
-                             f"but got {reduction}")
+        super().__init__(axis, reduce=reduction)


 class ExtractVolumePatches(Primitive):
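The behaviour is intended to stay the same; only the validation now lives in the auto-generated base class (whose keyword is spelled reduce, hence the reduce=reduction forwarding). A short usage sketch, assuming the documented call pattern:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

op = ops.TensorScatterElements(axis=0, reduction="none")
data = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), ms.float32)
indices = Tensor(np.array([[1, 0, 2], [0, 2, 1]]), ms.int32)
updates = Tensor(np.zeros((2, 3)), ms.float32)
# With reduction "none", each update overwrites data[indices[i][j]][j] along axis 0:
print(op(data, indices, updates))
# [[0. 0. 3.]
#  [0. 5. 0.]
#  [7. 0. 0.]]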