mindspore 2.3.0-cp39-cp39-win_amd64.whl → 2.4.0-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore has been flagged as possibly problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/ops/auto_generate/pyboost_inner_prim.py CHANGED

@@ -26,6 +26,8 @@ from mindspore._c_expression import BroadcastToPrim_
 from mindspore._c_expression import ConcatPrim_
 from mindspore._c_expression import ConvolutionGradPrim_
 from mindspore._c_expression import ConvolutionPrim_
+from mindspore._c_expression import CrossPrim_
+from mindspore._c_expression import CummaxPrim_
 from mindspore._c_expression import EluExtPrim_
 from mindspore._c_expression import FFNExtPrim_
 from mindspore._c_expression import FlashAttentionScoreGradPrim_
@@ -34,20 +36,30 @@ from mindspore._c_expression import GridSampler2DGradPrim_
 from mindspore._c_expression import GridSampler2DPrim_
 from mindspore._c_expression import GridSampler3DGradPrim_
 from mindspore._c_expression import GridSampler3DPrim_
+from mindspore._c_expression import HShrinkGradPrim_
+from mindspore._c_expression import HShrinkPrim_
+from mindspore._c_expression import IncreFlashAttentionPrim_
 from mindspore._c_expression import IsClosePrim_
+from mindspore._c_expression import LogSoftmaxGradPrim_
+from mindspore._c_expression import LogSoftmaxPrim_
 from mindspore._c_expression import MatMulPrim_
 from mindspore._c_expression import MaxPoolGradWithIndicesPrim_
 from mindspore._c_expression import MaxPoolGradWithMaskPrim_
 from mindspore._c_expression import MaxPoolWithIndicesPrim_
 from mindspore._c_expression import MaxPoolWithMaskPrim_
+from mindspore._c_expression import NanToNumPrim_
 from mindspore._c_expression import OneHotExtPrim_
 from mindspore._c_expression import ReduceAllPrim_
 from mindspore._c_expression import ReduceAnyPrim_
 from mindspore._c_expression import ReverseV2Prim_
 from mindspore._c_expression import RmsNormPrim_
+from mindspore._c_expression import RollPrim_
 from mindspore._c_expression import SearchSortedPrim_
 from mindspore._c_expression import SoftmaxPrim_
+from mindspore._c_expression import SoftShrinkGradPrim_
+from mindspore._c_expression import SoftShrinkPrim_
 from mindspore._c_expression import StackExtPrim_
+from mindspore._c_expression import TrilExtPrim_
 from mindspore._c_expression import TriuPrim_
 from mindspore._c_expression import UpsampleTrilinear3DGradPrim_
 from mindspore._c_expression import UpsampleTrilinear3DPrim_
@@ -94,8 +106,8 @@ batch_norm_grad_ext_impl = _PyboostBatchNormGradExtPrim()
 
 class _PyboostBinaryCrossEntropyGradPrim(BinaryCrossEntropyGradPrim_):
     def __call__(self, input, target, grad_output, weight, reduction):
-        converted_reduction = str_to_enum(reduction)
-        return _convert_stub(super().__call__(input, target, grad_output, weight,
+        converted_reduction = str_to_enum('binary_cross_entropy_grad', 'reduction', reduction)
+        return _convert_stub(super().__call__(input, target, grad_output, weight, converted_reduction))
 
 
 binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()
@@ -103,8 +115,8 @@ binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()
 
 class _PyboostBinaryCrossEntropyPrim(BinaryCrossEntropyPrim_):
     def __call__(self, input, target, weight, reduction):
-        converted_reduction = str_to_enum(reduction)
-        return _convert_stub(super().__call__(input, target, weight,
+        converted_reduction = str_to_enum('binary_cross_entropy', 'reduction', reduction)
+        return _convert_stub(super().__call__(input, target, weight, converted_reduction))
 
 
 binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()
@@ -112,8 +124,8 @@ binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()
 
 class _PyboostBCEWithLogitsLossPrim(BCEWithLogitsLossPrim_):
     def __call__(self, input, target, weight, posWeight, reduction):
-        converted_reduction = str_to_enum(reduction)
-        return _convert_stub(super().__call__(input, target, weight, posWeight,
+        converted_reduction = str_to_enum('binary_cross_entropy_with_logits', 'reduction', reduction)
+        return _convert_stub(super().__call__(input, target, weight, posWeight, converted_reduction))
 
 
 binary_cross_entropy_with_logits_impl = _PyboostBCEWithLogitsLossPrim()
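A pattern worth calling out across these generated wrappers: every converter call (str_to_enum here, and the shape converters further down) now receives the operator name and the argument name ahead of the value, plausibly so that conversion failures can name the offending operator and argument. A minimal sketch of the idea, with a hypothetical one-enum table rather than MindSpore's real helper:

    # Hypothetical sketch of the new converter convention, not MindSpore's
    # implementation: the op and argument names make the error message precise.
    _REDUCTION_ENUM = {'none': 0, 'mean': 1, 'sum': 2}

    def str_to_enum(op_name, arg_name, value):
        try:
            return _REDUCTION_ENUM[value]
        except KeyError:
            raise ValueError(f"For '{op_name}', the value of '{arg_name}' must be one of "
                             f"{list(_REDUCTION_ENUM)}, but got '{value}'.") from None

    print(str_to_enum('binary_cross_entropy', 'reduction', 'mean'))  # 1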
@@ -139,11 +151,11 @@ concat_impl = _PyboostConcatPrim()
 
 class _PyboostConvolutionGradPrim(ConvolutionGradPrim_):
     def __call__(self, dout, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, output_mask):
-        converted_stride = to_strides(stride)
-        converted_padding = to_2d_paddings(padding)
-        converted_dilation = to_dilations(dilation)
-        converted_output_padding = to_output_padding(output_padding)
-        return _convert_stub(super().__call__(dout, input, weight, bias,
+        converted_stride = to_strides('convolution_grad', 'stride', stride)
+        converted_padding = to_2d_paddings('convolution_grad', 'padding', padding)
+        converted_dilation = to_dilations('convolution_grad', 'dilation', dilation)
+        converted_output_padding = to_output_padding('convolution_grad', 'output_padding', output_padding)
+        return _convert_stub(super().__call__(dout, input, weight, bias, converted_stride, converted_padding, converted_dilation, transposed, converted_output_padding, groups, output_mask))
 
 
 convolution_grad_impl = _PyboostConvolutionGradPrim()
@@ -151,16 +163,34 @@ convolution_grad_impl = _PyboostConvolutionGradPrim()
 
 class _PyboostConvolutionPrim(ConvolutionPrim_):
     def __call__(self, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
-        converted_stride = to_strides(stride)
-        converted_padding = to_2d_paddings(padding)
-        converted_dilation = to_dilations(dilation)
-        converted_output_padding = to_output_padding(output_padding)
-        return _convert_stub(super().__call__(input, weight, bias,
+        converted_stride = to_strides('convolution', 'stride', stride)
+        converted_padding = to_2d_paddings('convolution', 'padding', padding)
+        converted_dilation = to_dilations('convolution', 'dilation', dilation)
+        converted_output_padding = to_output_padding('convolution', 'output_padding', output_padding)
+        return _convert_stub(super().__call__(input, weight, bias, converted_stride, converted_padding, converted_dilation, transposed, converted_output_padding, groups))
 
 
 convolution_impl = _PyboostConvolutionPrim()
 
 
+class _PyboostCrossPrim(CrossPrim_):
+    def __call__(self, input, other, dim):
+
+        return _convert_stub(super().__call__(input, other, dim))
+
+
+cross_impl = _PyboostCrossPrim()
+
+
+class _PyboostCummaxPrim(CummaxPrim_):
+    def __call__(self, input, axis):
+
+        return _convert_stub(super().__call__(input, axis))
+
+
+cummax_impl = _PyboostCummaxPrim()
+
+
 class _PyboostEluExtPrim(EluExtPrim_):
     def __call__(self, input, alpha):
 
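The geometry converters change the same way: to_strides, to_2d_paddings, to_dilations, and to_output_padding all gain the leading op-name and arg-name parameters. A hedged sketch of what such a normalizer can look like (illustrative behaviour, assumed rather than taken from MindSpore):

    # Illustrative normalizer only; the real to_strides lives in MindSpore's
    # generated helpers and may differ.
    def to_strides(op_name, arg_name, value):
        if isinstance(value, int):
            return (value, value)            # broadcast a scalar stride
        if isinstance(value, (tuple, list)) and all(isinstance(v, int) for v in value):
            return tuple(value)              # pass sequences through as tuples
        raise TypeError(f"For '{op_name}', '{arg_name}' must be an int or a "
                        f"sequence of ints, but got {value!r}.")

    print(to_strides('convolution', 'stride', 2))       # (2, 2)
    print(to_strides('convolution', 'stride', (1, 2)))  # (1, 2)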
@@ -172,8 +202,8 @@ elu_ext_impl = _PyboostEluExtPrim()
 
 class _PyboostFFNExtPrim(FFNExtPrim_):
     def __call__(self, x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, activation, inner_precise):
-        converted_activation = str_to_enum(activation)
-        return _convert_stub(super().__call__(x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2,
+        converted_activation = str_to_enum('ffn_ext', 'activation', activation)
+        return _convert_stub(super().__call__(x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, converted_activation, inner_precise))
 
 
 ffn_ext_impl = _PyboostFFNExtPrim()
@@ -181,8 +211,8 @@ ffn_ext_impl = _PyboostFFNExtPrim()
 
 class _PyboostFlashAttentionScoreGradPrim(FlashAttentionScoreGradPrim_):
     def __call__(self, query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode):
-        converted_input_layout = str_to_enum(input_layout)
-        return _convert_stub(super().__call__(query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise,
+        converted_input_layout = str_to_enum('flash_attention_score_grad', 'input_layout', input_layout)
+        return _convert_stub(super().__call__(query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode))
 
 
 flash_attention_score_grad_impl = _PyboostFlashAttentionScoreGradPrim()
@@ -190,8 +220,8 @@ flash_attention_score_grad_impl = _PyboostFlashAttentionScoreGradPrim()
 
 class _PyboostFlashAttentionScorePrim(FlashAttentionScorePrim_):
     def __call__(self, query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode):
-        converted_input_layout = str_to_enum(input_layout)
-        return _convert_stub(super().__call__(query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise,
+        converted_input_layout = str_to_enum('flash_attention_score', 'input_layout', input_layout)
+        return _convert_stub(super().__call__(query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode))
 
 
 flash_attention_score_impl = _PyboostFlashAttentionScorePrim()
@@ -199,9 +229,9 @@ flash_attention_score_impl = _PyboostFlashAttentionScorePrim()
 
 class _PyboostGridSampler2DGradPrim(GridSampler2DGradPrim_):
     def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners):
-        converted_interpolation_mode = str_to_enum(interpolation_mode)
-        converted_padding_mode = str_to_enum(padding_mode)
-        return _convert_stub(super().__call__(grad, input_x, grid,
+        converted_interpolation_mode = str_to_enum('grid_sampler_2d_grad', 'interpolation_mode', interpolation_mode)
+        converted_padding_mode = str_to_enum('grid_sampler_2d_grad', 'padding_mode', padding_mode)
+        return _convert_stub(super().__call__(grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners))
 
 
 grid_sampler_2d_grad_impl = _PyboostGridSampler2DGradPrim()
@@ -209,9 +239,9 @@ grid_sampler_2d_grad_impl = _PyboostGridSampler2DGradPrim()
 
 class _PyboostGridSampler2DPrim(GridSampler2DPrim_):
     def __call__(self, input_x, grid, interpolation_mode, padding_mode, align_corners):
-        converted_interpolation_mode = str_to_enum(interpolation_mode)
-        converted_padding_mode = str_to_enum(padding_mode)
-        return _convert_stub(super().__call__(input_x, grid,
+        converted_interpolation_mode = str_to_enum('grid_sampler_2d', 'interpolation_mode', interpolation_mode)
+        converted_padding_mode = str_to_enum('grid_sampler_2d', 'padding_mode', padding_mode)
+        return _convert_stub(super().__call__(input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners))
 
 
 grid_sampler_2d_impl = _PyboostGridSampler2DPrim()
@@ -219,9 +249,9 @@ grid_sampler_2d_impl = _PyboostGridSampler2DPrim()
 
 class _PyboostGridSampler3DGradPrim(GridSampler3DGradPrim_):
     def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners):
-        converted_interpolation_mode = str_to_enum(interpolation_mode)
-        converted_padding_mode = str_to_enum(padding_mode)
-        return _convert_stub(super().__call__(grad, input_x, grid,
+        converted_interpolation_mode = str_to_enum('grid_sampler_3d_grad', 'interpolation_mode', interpolation_mode)
+        converted_padding_mode = str_to_enum('grid_sampler_3d_grad', 'padding_mode', padding_mode)
+        return _convert_stub(super().__call__(grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners))
 
 
 grid_sampler_3d_grad_impl = _PyboostGridSampler3DGradPrim()
@@ -229,14 +259,41 @@ grid_sampler_3d_grad_impl = _PyboostGridSampler3DGradPrim()
 
 class _PyboostGridSampler3DPrim(GridSampler3DPrim_):
     def __call__(self, input_x, grid, interpolation_mode, padding_mode, align_corners):
-        converted_interpolation_mode = str_to_enum(interpolation_mode)
-        converted_padding_mode = str_to_enum(padding_mode)
-        return _convert_stub(super().__call__(input_x, grid,
+        converted_interpolation_mode = str_to_enum('grid_sampler_3d', 'interpolation_mode', interpolation_mode)
+        converted_padding_mode = str_to_enum('grid_sampler_3d', 'padding_mode', padding_mode)
+        return _convert_stub(super().__call__(input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners))
 
 
 grid_sampler_3d_impl = _PyboostGridSampler3DPrim()
 
 
+class _PyboostHShrinkGradPrim(HShrinkGradPrim_):
+    def __call__(self, gradients, features, lambd):
+
+        return _convert_stub(super().__call__(gradients, features, lambd))
+
+
+hshrink_grad_impl = _PyboostHShrinkGradPrim()
+
+
+class _PyboostHShrinkPrim(HShrinkPrim_):
+    def __call__(self, input, lambd):
+
+        return _convert_stub(super().__call__(input, lambd))
+
+
+hshrink_impl = _PyboostHShrinkPrim()
+
+
+class _PyboostIncreFlashAttentionPrim(IncreFlashAttentionPrim_):
+    def __call__(self, query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, input_layout, scale_value, num_key_value_heads, block_size, inner_precise):
+        converted_input_layout = str_to_enum('incre_flash_attention', 'input_layout', input_layout)
+        return _convert_stub(super().__call__(query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, converted_input_layout, scale_value, num_key_value_heads, block_size, inner_precise))
+
+
+incre_flash_attention_impl = _PyboostIncreFlashAttentionPrim()
+
+
 class _PyboostIsClosePrim(IsClosePrim_):
     def __call__(self, input, other, rtol, atol, equal_nan):
 
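All of the newly added wrappers (Cross, Cummax, HShrink, SoftShrink, NanToNum, Roll, TrilExt, LogSoftmax, and so on) share one generated shape: a thin Python subclass of a compiled primitive whose __call__ forwards straight to the C++ kernel and wraps the result. A runnable mock-up of that shape with stand-in stubs, since the real base classes come from the compiled mindspore._c_expression module:

    class HShrinkPrim_:
        # Stand-in for the compiled primitive exported by mindspore._c_expression.
        def __call__(self, *args):
            return ('HShrink', args)

    def _convert_stub(out):
        # Stand-in for the real stub-tensor conversion helper.
        return out

    class _PyboostHShrinkPrim(HShrinkPrim_):
        def __call__(self, input, lambd):
            return _convert_stub(super().__call__(input, lambd))

    hshrink_impl = _PyboostHShrinkPrim()
    print(hshrink_impl([1.0, -0.2, 0.7], 0.5))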
@@ -246,6 +303,24 @@ class _PyboostIsClosePrim(IsClosePrim_):
 isclose_impl = _PyboostIsClosePrim()
 
 
+class _PyboostLogSoftmaxGradPrim(LogSoftmaxGradPrim_):
+    def __call__(self, logits, grad, axis):
+
+        return _convert_stub(super().__call__(logits, grad, axis))
+
+
+log_softmax_grad_impl = _PyboostLogSoftmaxGradPrim()
+
+
+class _PyboostLogSoftmaxPrim(LogSoftmaxPrim_):
+    def __call__(self, logits, axis):
+
+        return _convert_stub(super().__call__(logits, axis))
+
+
+log_softmax_impl = _PyboostLogSoftmaxPrim()
+
+
 class _PyboostMatMulPrim(MatMulPrim_):
     def __call__(self, input, mat2, transpose_a, transpose_b):
 
@@ -257,11 +332,11 @@ matmul_impl = _PyboostMatMulPrim()
 
 class _PyboostMaxPoolGradWithIndicesPrim(MaxPoolGradWithIndicesPrim_):
     def __call__(self, x, grad, argmax, kernel_size, strides, pads, dilation, ceil_mode, argmax_type):
-        converted_kernel_size = to_kernel_size(kernel_size)
-        converted_strides = to_strides(strides)
-        converted_pads = to_output_padding(pads)
-        converted_dilation = to_dilations(dilation)
-        return _convert_stub(super().__call__(x, grad, argmax,
+        converted_kernel_size = to_kernel_size('max_pool_grad_with_indices', 'kernel_size', kernel_size)
+        converted_strides = to_strides('max_pool_grad_with_indices', 'strides', strides)
+        converted_pads = to_output_padding('max_pool_grad_with_indices', 'pads', pads)
+        converted_dilation = to_dilations('max_pool_grad_with_indices', 'dilation', dilation)
+        return _convert_stub(super().__call__(x, grad, argmax, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type))
 
 
 max_pool_grad_with_indices_impl = _PyboostMaxPoolGradWithIndicesPrim()
@@ -269,11 +344,11 @@ max_pool_grad_with_indices_impl = _PyboostMaxPoolGradWithIndicesPrim()
 
 class _PyboostMaxPoolGradWithMaskPrim(MaxPoolGradWithMaskPrim_):
     def __call__(self, x, grad, mask, kernel_size, strides, pads, dilation, ceil_mode, argmax_type):
-        converted_kernel_size = to_kernel_size(kernel_size)
-        converted_strides = to_strides(strides)
-        converted_pads = to_output_padding(pads)
-        converted_dilation = to_dilations(dilation)
-        return _convert_stub(super().__call__(x, grad, mask,
+        converted_kernel_size = to_kernel_size('max_pool_grad_with_mask', 'kernel_size', kernel_size)
+        converted_strides = to_strides('max_pool_grad_with_mask', 'strides', strides)
+        converted_pads = to_output_padding('max_pool_grad_with_mask', 'pads', pads)
+        converted_dilation = to_dilations('max_pool_grad_with_mask', 'dilation', dilation)
+        return _convert_stub(super().__call__(x, grad, mask, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type))
 
 
 max_pool_grad_with_mask_impl = _PyboostMaxPoolGradWithMaskPrim()
@@ -281,11 +356,11 @@ max_pool_grad_with_mask_impl = _PyboostMaxPoolGradWithMaskPrim()
 
 class _PyboostMaxPoolWithIndicesPrim(MaxPoolWithIndicesPrim_):
    def __call__(self, x, kernel_size, strides, pads, dilation, ceil_mode, argmax_type):
-        converted_kernel_size = to_kernel_size(kernel_size)
-        converted_strides = to_strides(strides)
-        converted_pads = to_output_padding(pads)
-        converted_dilation = to_dilations(dilation)
-        return _convert_stub(super().__call__(x,
+        converted_kernel_size = to_kernel_size('max_pool_with_indices', 'kernel_size', kernel_size)
+        converted_strides = to_strides('max_pool_with_indices', 'strides', strides)
+        converted_pads = to_output_padding('max_pool_with_indices', 'pads', pads)
+        converted_dilation = to_dilations('max_pool_with_indices', 'dilation', dilation)
+        return _convert_stub(super().__call__(x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type))
 
 
 max_pool_with_indices_impl = _PyboostMaxPoolWithIndicesPrim()
@@ -293,16 +368,25 @@ max_pool_with_indices_impl = _PyboostMaxPoolWithIndicesPrim()
 
 class _PyboostMaxPoolWithMaskPrim(MaxPoolWithMaskPrim_):
     def __call__(self, x, kernel_size, strides, pads, dilation, ceil_mode, argmax_type):
-        converted_kernel_size = to_kernel_size(kernel_size)
-        converted_strides = to_strides(strides)
-        converted_pads = to_output_padding(pads)
-        converted_dilation = to_dilations(dilation)
-        return _convert_stub(super().__call__(x,
+        converted_kernel_size = to_kernel_size('max_pool_with_mask', 'kernel_size', kernel_size)
+        converted_strides = to_strides('max_pool_with_mask', 'strides', strides)
+        converted_pads = to_output_padding('max_pool_with_mask', 'pads', pads)
+        converted_dilation = to_dilations('max_pool_with_mask', 'dilation', dilation)
+        return _convert_stub(super().__call__(x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type))
 
 
 max_pool_with_mask_impl = _PyboostMaxPoolWithMaskPrim()
 
 
+class _PyboostNanToNumPrim(NanToNumPrim_):
+    def __call__(self, input, nan, posinf, neginf):
+
+        return _convert_stub(super().__call__(input, nan, posinf, neginf))
+
+
+nan_to_num_impl = _PyboostNanToNumPrim()
+
+
 class _PyboostOneHotExtPrim(OneHotExtPrim_):
     def __call__(self, tensor, num_classes, on_value, off_value, axis):
 
@@ -348,6 +432,15 @@ class _PyboostRmsNormPrim(RmsNormPrim_):
 rms_norm_impl = _PyboostRmsNormPrim()
 
 
+class _PyboostRollPrim(RollPrim_):
+    def __call__(self, input, shift, axis):
+
+        return _convert_stub(super().__call__(input, shift, axis))
+
+
+roll_impl = _PyboostRollPrim()
+
+
 class _PyboostSearchSortedPrim(SearchSortedPrim_):
     def __call__(self, sorted_sequence, values, sorter, dtype, right):
 
@@ -366,6 +459,24 @@ class _PyboostSoftmaxPrim(SoftmaxPrim_):
 softmax_impl = _PyboostSoftmaxPrim()
 
 
+class _PyboostSoftShrinkGradPrim(SoftShrinkGradPrim_):
+    def __call__(self, input_grad, input_x, lambd):
+
+        return _convert_stub(super().__call__(input_grad, input_x, lambd))
+
+
+softshrink_grad_impl = _PyboostSoftShrinkGradPrim()
+
+
+class _PyboostSoftShrinkPrim(SoftShrinkPrim_):
+    def __call__(self, input, lambd):
+
+        return _convert_stub(super().__call__(input, lambd))
+
+
+softshrink_impl = _PyboostSoftShrinkPrim()
+
+
 class _PyboostStackExtPrim(StackExtPrim_):
     def __call__(self, tensors, dim):
 
@@ -375,6 +486,15 @@ class _PyboostStackExtPrim(StackExtPrim_):
 stack_ext_impl = _PyboostStackExtPrim()
 
 
+class _PyboostTrilExtPrim(TrilExtPrim_):
+    def __call__(self, input, diagonal):
+
+        return _convert_stub(super().__call__(input, diagonal))
+
+
+tril_ext_impl = _PyboostTrilExtPrim()
+
+
 class _PyboostTriuPrim(TriuPrim_):
     def __call__(self, input, diagonal):
 
@@ -412,9 +532,9 @@ grouped_matmul_impl = _PyboostGroupedMatmulPrim()
 
 
 class _PyboostQuantBatchMatmulPrim(QuantBatchMatmulPrim_):
-    def __call__(self, x1, x2, scale, offset, bias, transpose_x1, transpose_x2, dtype):
+    def __call__(self, x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype):
 
-        return _convert_stub(super().__call__(x1, x2, scale, offset, bias, transpose_x1, transpose_x2, dtype))
+        return _convert_stub(super().__call__(x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype))
 
 
 quant_batch_matmul_impl = _PyboostQuantBatchMatmulPrim()
mindspore/ops/composite/base.py CHANGED
@@ -30,7 +30,7 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
     SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
     ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
     ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_, DictSetItem_, \
-    HandleBoolTensor_, PreSetitemByTuple_, StarredGetItem_
+    HandleBoolTensor_, PreSetitemByTuple_, StarredGetItem_, \
     StarredUnpack_, StarredUnpackMerge_, IterConverter_, HasNext_, Next_, MSContext
 from mindspore.common import dtype as mstype
 from mindspore.common.api import jit, _pynative_executor, _wrap_func
@@ -346,9 +346,11 @@ class GradOperation(GradOperation_):
         self.grad_position = (0,)
 
     def __call__(self, fn, weights=None):
-        weights_id =
-        if
-
+        weights_id = ''
+        if context.get_context("mode") == context.GRAPH_MODE:
+            weights_id = _get_grad_weights_id(weights)
+        if self.grad_fn is not None and self.fn == fn and self.weights_id == weights_id:
+            return self.grad_fn
         grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
         # If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
         # If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
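The rewritten __call__ short-circuits: in GRAPH_MODE the cached grad function is keyed on both the forward fn and an id derived from the weights, so repeated calls with the same pair reuse the compiled grad graph instead of rebuilding it. A simplified stand-alone model of that caching (hypothetical names; _get_grad_weights_id is approximated with id()):

    class CachedGrad:
        def __init__(self):
            self.fn = None
            self.weights_id = None
            self.grad_fn = None

        def __call__(self, fn, weights=None, graph_mode=True):
            # Graph mode keys the cache on the weights identity as well.
            weights_id = str(id(weights)) if graph_mode else ''
            if self.grad_fn is not None and self.fn is fn and self.weights_id == weights_id:
                return self.grad_fn  # cache hit: reuse the compiled grad graph
            def grad_fn(*args):
                return ('grad', fn.__name__, args)  # stand-in for real compilation
            self.fn, self.weights_id, self.grad_fn = fn, weights_id, grad_fn
            return grad_fn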
@@ -374,8 +376,8 @@ class GradOperation(GradOperation_):
 
             @_wrap_func
             def after_grad(*args, **kwargs):
-                self._pynative_forward_run(fn, grad_, weights, args, kwargs)
-                out = _pynative_executor.grad(fn, grad_, weights, self.grad_position, *
+                run_args = self._pynative_forward_run(fn, grad_, weights, *args, **kwargs)
+                out = _pynative_executor.grad(fn, grad_, weights, self.grad_position, *run_args)
                 out = _grads_divided_by_device_num_if_recomputation(out)
                 return out
         else:
@@ -396,26 +398,39 @@ class GradOperation(GradOperation_):
             self.weights_id = weights_id
         return self.grad_fn
 
-    def _pynative_forward_run(self, fn, grad, weights, args, kwargs):
-        """
-
+    def _pynative_forward_run(self, fn, grad, weights, *args, **kwargs):
+        """ PyNative forward run to build grad graph. """
+        sens = None
         if self.sens_param:
-            if 'sens'
-
+            if 'sens' in kwargs.keys():
+                sens = kwargs.pop('sens')
             else:
-
-
+                # default use args last elem as sens
+                sens = args[-1]
+                args = args[:-1]
+        run_args = args
+        if kwargs:
+            run_args = args + tuple(kwargs.values())
+
+        # check run exclude sens
         if isinstance(fn, (FunctionType, MethodType)):
-            if not _pynative_executor.check_run(grad, fn, weights, None, *
+            if not _pynative_executor.check_run(grad, fn, weights, None, *run_args):
                 _pynative_executor.set_grad_flag(True)
-                _pynative_executor.new_graph(fn, *args, **
-                output = fn(*args, **
-                _pynative_executor.end_graph(fn, output, *args, **
+                _pynative_executor.new_graph(fn, *args, **kwargs)
+                output = fn(*args, **kwargs)
+                _pynative_executor.end_graph(fn, output, *args, **kwargs)
         else:
-            # Check if fn
-            if not _pynative_executor.check_run(grad, fn, weights, None, *
-
-            fn
+            # Check if fn has run already
+            if not _pynative_executor.check_run(grad, fn, weights, None, *run_args):
+                requires_grad = fn.requires_grad
+                fn.requires_grad = True
+                fn(*args, **kwargs)
+                fn.requires_grad = requires_grad
+
+        # If it has sens, keep sens as the last element
+        if sens is not None:
+            run_args += (sens,) if sens is not isinstance(run_args, tuple) else sens
+        return run_args
 
 
 class _TaylorOperation(TaylorOperation_):
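_pynative_forward_run now takes *args and **kwargs directly. It peels the sensitivity value off first (the 'sens' keyword if present, otherwise the last positional argument), flattens any remaining kwargs into the positional run_args used by check_run and grad, and re-appends sens at the end. A simplified, standalone sketch of just that argument plumbing (the tuple test is written the conventional way here):

    def split_sens(sens_param, *args, **kwargs):
        sens = None
        if sens_param:
            if 'sens' in kwargs:
                sens = kwargs.pop('sens')
            else:
                sens = args[-1]      # default: last positional argument is sens
                args = args[:-1]
        run_args = args + tuple(kwargs.values()) if kwargs else args
        if sens is not None:
            run_args += sens if isinstance(sens, tuple) else (sens,)  # sens stays last
        return run_args

    print(split_sens(True, 1, 2, 3))  # (1, 2, 3): sens=3 is peeled off and re-appended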
@@ -552,13 +567,15 @@ class _Grad(GradOperation_):
         self.weights_id = None
 
     def __call__(self, fn, weights=None, grad_position=0):
-        weights_id =
-        if
-
-
+        weights_id = ''
+        if context.get_context("mode") == context.GRAPH_MODE:
+            weights_id = _get_grad_weights_id(weights)
+        if self.grad_fn is not None and self.fn == fn and self.grad_position == grad_position and \
+                self.weights_id == weights_id:
+            return self.grad_fn
 
-        def aux_fn(*args):
-            outputs = fn(*args)
+        def aux_fn(*args, **kwargs):
+            outputs = fn(*args, **kwargs)
             if not isinstance(outputs, tuple) or len(outputs) < 2:
                 raise ValueError("When has_aux is True, origin fn requires more than one outputs.")
             res = (outputs[0],)
@@ -597,8 +614,8 @@ class _Grad(GradOperation_):
 
             @_wrap_func
             def after_grad(*args, **kwargs):
-                res = self._pynative_forward_run(fn, grad_, weights, args, kwargs)
-                out = _pynative_executor.grad(fn, grad_, weights, grad_position, *
+                run_args, res = self._pynative_forward_run(fn, grad_, weights, *args, **kwargs)
+                out = _pynative_executor.grad(fn, grad_, weights, grad_position, *run_args)
                 out = _grads_divided_by_device_num_if_recomputation(out)
                 if self.return_ids and out:
                     out = _combine_with_ids(grad_position, weights, out)
@@ -633,32 +650,49 @@ class _Grad(GradOperation_):
             self.weights_id = weights_id
         return self.grad_fn
 
-    def _pynative_forward_run(self, fn, grad, weights, args, kwargs):
-        """
-
-        outputs = ()
+    def _pynative_forward_run(self, fn, grad, weights, *args, **kwargs):
+        """ PyNative forward runs to build grad graph. """
+        sens = None
         if self.sens_param:
             if 'sens' in kwargs.keys():
-
-                new_kwargs.pop('sens')
+                sens = kwargs.pop('sens')
             else:
+                # default use args last elem as sens
+                sens = args[-1]
                 args = args[:-1]
+        run_args = args
+        if kwargs:
+            run_args = args + tuple(kwargs.values())
+
+        # check run exclude sens
+        outputs = ()
+        run_forward = False
         if isinstance(fn, (FunctionType, MethodType)):
-            if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *
+            if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *run_args):
                 _pynative_executor.set_grad_flag(True)
-                _pynative_executor.new_graph(fn, *args, **
-                outputs = fn(*args, **
-                _pynative_executor.end_graph(fn, outputs, *args, **
-
+                _pynative_executor.new_graph(fn, *args, **kwargs)
+                outputs = fn(*args, **kwargs)
+                _pynative_executor.end_graph(fn, outputs, *args, **kwargs)
+                run_forward = True
         else:
             # Check if fn has run already.
-            if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *
-
-
-
+            if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *run_args):
+                requires_grad = fn.requires_grad
+                fn.requires_grad = True
+                outputs = fn(*args, **kwargs)
+                fn.requires_grad = requires_grad
+                run_forward = True
+        # If it has sens, keep sens as the last element
+        if sens is not None:
+            run_args += (sens,) if sens is not isinstance(run_args, tuple) else sens
+
+        # Normal run grad
+        if run_forward:
+            return run_args, outputs
+
         if (self.get_value or self.has_aux) and not outputs:
-            outputs = fn(*args, **
-            return outputs
+            outputs = fn(*args, **kwargs)
+        return run_args, outputs
 
 
 class _Vmap(VmapOperation_):
@@ -806,10 +840,12 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
 
 class HyperMap(HyperMap_):
     """
-
+    HyperMap will apply the set operation to input sequences.
 
     Apply the operations to every element of the sequence or nested sequence. Different
-    from `mindspore.ops.Map`, the `HyperMap` supports to apply on nested structure.
+    from `mindspore.ops.Map`, the `HyperMap` supports to apply on nested structure. The
+    `HyperMap` also supports dynamic sequences as input, but it does not extend this
+    support to nested dynamic sequences.
 
     Args:
         ops (Union[MultitypeFuncGraph, None], optional): `ops` is the operation to apply. If `ops` is `None`,
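The docstring now states the nesting behaviour explicitly: HyperMap applies an operation element-wise over possibly nested sequences, and supports dynamic sequences only at the top level. A plain-Python analogue of the nested application, purely conceptual rather than MindSpore graph code:

    def hyper_map(fn, *seqs):
        # Recurse into nested tuples/lists element-wise; apply fn at the leaves.
        if isinstance(seqs[0], (list, tuple)):
            return type(seqs[0])(hyper_map(fn, *items) for items in zip(*seqs))
        return fn(*seqs)

    print(hyper_map(lambda a, b: a + b, (1, (2, 3)), (10, (20, 30))))  # (11, (22, 33))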
@@ -959,6 +995,7 @@ class _ListAppend(ListAppend_):
     Args:
         name (str): The name of the metafuncgraph object.
     """
+
     # `__init__` method removed entirely
     def __call__(self, *args):
         pass
mindspore/ops/composite/multitype_ops/_compile_utils.py CHANGED

@@ -483,6 +483,7 @@ def format_index_tensor(index, arg):
             index[format_idx] = F.select(index_tensor < 0, index_tensor + format_dim, index_tensor)
         return index
     index = Tensor(index)
+    format_dims = Tensor(format_dims)
     return F.select(index < 0, index + format_dims, index)
 
 
mindspore/ops/composite/multitype_ops/not_in_impl.py CHANGED

@@ -41,7 +41,7 @@ def _number_not_in_tuple(x, y):
     Returns:
         bool, if x not in y return true, x in y return false.
     """
-    if F.
+    if F.is_sequence_value_unknown(y) or not F.isconstant(x):
         return not InSequence()(x, y)
     return not const_utils.scalar_in_sequence(x, y)
 
@@ -58,7 +58,7 @@ def _number_not_in_list(x, y):
     Returns:
         bool, if x not in y return true, x in y return false.
     """
-    if F.
+    if F.is_sequence_value_unknown(y) or not F.isconstant(x):
         return not InSequence()(x, y)
     return not const_utils.scalar_in_sequence(x, y)
 
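Both not-in specializations now fall back to the runtime InSequence check whenever the sequence's value is unknown at compile time (for example a dynamic-length sequence) or x is not a compile-time constant; only fully static operands are folded by const_utils.scalar_in_sequence. A conceptual analogue of that dispatch in plain Python:

    def fold_not_in(x, y, x_is_constant, y_value_known):
        if y_value_known and x_is_constant:
            return ('const', x not in y)        # fold at "compile time"
        return ('runtime', lambda: x not in y)  # defer: InSequence analogue

    kind, value = fold_not_in(3, (1, 2), x_is_constant=True, y_value_known=True)
    print(kind, value)  # const True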