mindspore 2.2.0__cp38-none-any.whl → 2.2.10__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/_akg/akg/composite/build_module.py +9 -15
- mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/kernel_exec.py +41 -15
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +27 -6
- mindspore/_akg/akg/utils/util.py +38 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +3 -3
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/splitter.py +3 -2
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
- mindspore/_extends/parse/standard_method.py +2 -9
- mindspore/_extends/remote/kernel_build_server.py +2 -1
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/common/api.py +1 -1
- mindspore/common/auto_dynamic_shape.py +81 -85
- mindspore/common/dump.py +1 -1
- mindspore/common/tensor.py +3 -20
- mindspore/config/op_info.config +1 -1
- mindspore/context.py +11 -4
- mindspore/dataset/engine/datasets_standard_format.py +5 -0
- mindspore/dataset/vision/transforms.py +21 -21
- mindspore/experimental/optim/adam.py +1 -1
- mindspore/gen_ops.py +1 -1
- mindspore/include/api/model.h +17 -0
- mindspore/include/api/status.h +8 -3
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8 -80
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/nn/cell.py +0 -3
- mindspore/nn/layer/activation.py +4 -5
- mindspore/nn/layer/conv.py +39 -23
- mindspore/nn/layer/flash_attention.py +90 -78
- mindspore/nn/layer/math.py +3 -7
- mindspore/nn/layer/rnn_cells.py +5 -5
- mindspore/nn/wrap/cell_wrapper.py +6 -0
- mindspore/numpy/utils_const.py +5 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
- mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_utils/utils.py +2 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
- mindspore/ops/function/array_func.py +10 -7
- mindspore/ops/function/grad/grad_func.py +0 -1
- mindspore/ops/function/nn_func.py +98 -9
- mindspore/ops/function/random_func.py +2 -1
- mindspore/ops/op_info_register.py +24 -21
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_grad_ops.py +24 -4
- mindspore/ops/operations/_inner_ops.py +155 -23
- mindspore/ops/operations/array_ops.py +9 -7
- mindspore/ops/operations/comm_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +85 -68
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +4 -3
- mindspore/ops/operations/nn_ops.py +109 -28
- mindspore/parallel/_parallel_serialization.py +10 -3
- mindspore/parallel/_tensor.py +4 -1
- mindspore/parallel/checkpoint_transform.py +13 -2
- mindspore/parallel/shard.py +17 -10
- mindspore/profiler/common/util.py +1 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
- mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
- mindspore/profiler/parser/ascend_op_generator.py +1 -1
- mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
- mindspore/profiler/parser/base_timeline_generator.py +1 -1
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
- mindspore/profiler/parser/framework_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +19 -0
- mindspore/profiler/profiling.py +46 -24
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/parsers/for_parser.py +1 -1
- mindspore/rewrite/symbol_tree.py +1 -4
- mindspore/run_check/_check_version.py +5 -3
- mindspore/safeguard/rewrite_obfuscation.py +52 -28
- mindspore/train/callback/_summary_collector.py +1 -1
- mindspore/train/dataset_helper.py +1 -0
- mindspore/train/model.py +2 -2
- mindspore/train/serialization.py +97 -11
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +23 -7
- mindspore/version.py +1 -1
- {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +1 -1
- {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +138 -118
- {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/conv.py
CHANGED
@@ -718,9 +718,9 @@ class Conv3d(_Conv):

 .. math::
 \begin{array}{ll} \\
- D_{out}
- H_{out}
- W_{out}
+ D_{out} = \left \lceil{\frac{D_{in}}{\text{stride[0]}}} \right \rceil \\
+ H_{out} = \left \lceil{\frac{H_{in}}{\text{stride[1]}}} \right \rceil \\
+ W_{out} = \left \lceil{\frac{W_{in}}{\text{stride[2]}}} \right \rceil \\
 \end{array}


@@ -728,11 +728,11 @@ class Conv3d(_Conv):

 .. math::
 \begin{array}{ll} \\
- D_{out}
+ D_{out} = \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
 {\text{stride[0]}} + 1} \right \rfloor \\
- H_{out}
+ H_{out} = \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
 {\text{stride[1]}} + 1} \right \rfloor \\
- W_{out}
+ W_{out} = \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
 {\text{stride[2]}} + 1} \right \rfloor \\
 \end{array}

@@ -740,11 +740,11 @@ class Conv3d(_Conv):

 .. math::
 \begin{array}{ll} \\
- D_{out}
+ D_{out} = \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
 \text{kernel_size[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
- H_{out}
+ H_{out} = \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
 \text{kernel_size[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
- W_{out}
+ W_{out} = \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
 \text{kernel_size[2]} - 1 }{\text{stride[2]}} + 1} \right \rfloor \\
 \end{array}

@@ -812,7 +812,7 @@ class Conv3d(_Conv):
 bias_init,
 data_format,
 dtype=dtype)
- out_channels = self.out_channels
+ out_channels = self.out_channels // group
 self.conv3d = P.Conv3D(out_channel=out_channels,
 kernel_size=self.kernel_size,
 mode=1,
@@ -820,17 +820,33 @@ class Conv3d(_Conv):
 pad=self.padding,
 stride=self.stride,
 dilation=self.dilation,
- group=
+ group=1,
 data_format=self.data_format)
 self.bias_add = P.BiasAdd(data_format=self.data_format)
 self.shape = P.Shape()
+ self.concat = P.Concat(1)
+ self.split_0 = P.Split(0, self.group)
+ self.split_1 = P.Split(1, self.group)

 def construct(self, x):
 x_shape = self.shape(x)
 _check_input_5dims(x_shape, self.cls_name)
-
-
-
+ if self.group == 1:
+ out = self.conv3d(x, self.weight)
+ if self.has_bias:
+ out = self.bias_add(out, self.bias)
+ else:
+ features = self.split_1(x)
+ weights = self.split_0(self.weight)
+ outputs = ()
+ for i in range(self.group):
+ output = self.conv3d(features[i], weights[i])
+ outputs = outputs + (output,)
+ out = self.concat(outputs)
+ if self.bias is not None:
+ new_shape = [1 for _ in range(out.ndim)]
+ new_shape[1] = self.out_channels
+ out = out + self.bias.reshape(new_shape)
 return out


@@ -921,9 +937,9 @@ class Conv3dTranspose(_Conv):

 .. math::
 \begin{array}{ll} \\
- D_{out}
- H_{out}
- W_{out}
+ D_{out} = \left \lfloor{\frac{D_{in}}{\text{stride[0]}} + 1} \right \rfloor \\
+ H_{out} = \left \lfloor{\frac{H_{in}}{\text{stride[1]}} + 1} \right \rfloor \\
+ W_{out} = \left \lfloor{\frac{W_{in}}{\text{stride[2]}} + 1} \right \rfloor \\
 \end{array}


@@ -931,11 +947,11 @@ class Conv3dTranspose(_Conv):

 .. math::
 \begin{array}{ll} \\
- D_{out}
+ D_{out} = \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
 {\text{stride[0]}} + 1} \right \rfloor \\
- H_{out}
+ H_{out} = \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
 {\text{stride[1]}} + 1} \right \rfloor \\
- W_{out}
+ W_{out} = \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
 {\text{stride[2]}} + 1} \right \rfloor \\
 \end{array}

@@ -943,11 +959,11 @@ class Conv3dTranspose(_Conv):

 .. math::
 \begin{array}{ll} \\
- D_{out}
+ D_{out} = \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
 \text{kernel_size[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
- H_{out}
+ H_{out} = \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
 \text{kernel_size[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
- W_{out}
+ W_{out} = \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
 \text{kernel_size[2]} - 1 }{\text{stride[2]}} + 1} \right \rfloor \\
 \end{array}

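Note on the Conv3d change above: instead of passing the group count to P.Conv3D, the cell now splits the input along the channel axis (P.Split(1, group)) and the weight along the output-channel axis (P.Split(0, group)), runs a plain convolution per group, and concatenates the results; that is also why out_channels // group is handed to the underlying P.Conv3D. A minimal NumPy sketch of that equivalence, restricted to a 1x1x1 kernel so each per-group convolution collapses to an einsum (illustrative only, not the MindSpore kernel):

import numpy as np

def grouped_pointwise_conv3d(x, w, group):
    # x: (N, C_in, D, H, W); w: (C_out, C_in // group, 1, 1, 1)
    x_groups = np.split(x, group, axis=1)   # split features along the channel axis
    w_groups = np.split(w, group, axis=0)   # split weights along the output-channel axis
    outs = []
    for xg, wg in zip(x_groups, w_groups):
        # a 1x1x1 convolution is just a matmul over the channel dimension
        outs.append(np.einsum('ncdhw,oc->nodhw', xg, wg[:, :, 0, 0, 0]))
    return np.concatenate(outs, axis=1)     # concat per-group outputs back to C_out channels

x = np.random.rand(2, 4, 3, 3, 3)
w = np.random.rand(6, 2, 1, 1, 1)           # group=2: each group maps 2 input channels to 3 outputs
print(grouped_pointwise_conv3d(x, w, group=2).shape)   # (2, 6, 3, 3, 3)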
mindspore/nn/layer/flash_attention.py
CHANGED
@@ -57,14 +57,15 @@ class FlashAttention(Cell):
 Default True
 alibi(bool): This parameter indicates whether the flashattention supports the Alibi.
 Default: False
+ use_mqa(bool): Using MHA if True, only take effect under 910B. Default: False.


 Inputs:
 - **query** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
 - **key** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
 - **value** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
- - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16`
-
+ - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` `mstype.uint8`
+ [batch_size, seq_length, seq_length]): A matrix to pass masked information.

 Outputs:
 A Tensor. The output of the attention with shape [batch_size, head_num, seq_length, head_dim]
@@ -102,17 +103,23 @@ class FlashAttention(Cell):
 mp=1,
 high_precision=False,
 have_attention_mask_batch=True,
- alibi=False
+ alibi=False,
+ use_mqa=False
 ):
 super(FlashAttention, self).__init__()

 scaling_constant = math.sqrt(head_dim)
 if scaling_constant == 0:
 raise ValueError("the scaling constant must not be 0.")
- self.
-
- self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "Ascend910"
+ self.dropout_rate = dropout_rate
+ self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "ascend910"
 if self.is_910A:
+ self.scale_factor = Tensor([1. / math.sqrt(scaling_constant)], dtype=mstype.float16)
+ self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
+ self.ones = ops.Ones()
+ self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
+ self.have_attention_mask_batch = have_attention_mask_batch
+ self.alibi = alibi
 self.flash_attention = get_flash_attention(
 prev_block_num=prev_block_num,
 next_block_num=next_block_num,
@@ -120,6 +127,10 @@ class FlashAttention(Cell):
 high_precision=high_precision
 )
 self.flash_attention.add_prim_attr("primitive_target", "Ascend")
+ fa_strategies = ((dp, mp, 1, 1),
+ (dp, mp, 1, 1),
+ (dp, mp, 1, 1))
+ self.shard(fa_strategies)
 else:
 if alibi:
 raise ValueError(f"When soc_version is not Ascend910A, alibi must be False")
@@ -128,25 +139,27 @@ class FlashAttention(Cell):
 self.reshape = ops.Reshape()
 self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
 self.zeros = ops.Zeros()
- self.
-
-
-
-
+ self.attn_cast = ops.Cast()
+ if use_mqa:
+ fa_strategies = ((dp, mp, 1, 1),
+ (dp, 1, 1, 1),
+ (dp, 1, 1, 1),
+ (dp, 1, 1, 1))
+ else:
+ fa_strategies = ((dp, mp, 1, 1),
+ (dp, mp, 1, 1),
+ (dp, mp, 1, 1),
+ (dp, 1, 1, 1))
 if dropout_rate > 1e-5:
 fa_strategies += ((dp, mp, 1, 1),)
 self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
 next_tokens=next_block_num,
 keep_prob=1 - dropout_rate,
- scale_value=1.
- inner_precise=0 if high_precision else 1
+ scale_value=1. / scaling_constant,
+ inner_precise=0 if high_precision else 1,
+ input_layout="BNSD").shard(fa_strategies)

- self.ones = ops.Ones()
- self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
- self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
 self.dropout_rate = dropout_rate
- self.have_attention_mask_batch = have_attention_mask_batch
- self.alibi = alibi
 if self.dropout_rate > 1e-5:
 self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
 self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))
@@ -162,46 +175,49 @@ class FlashAttention(Cell):
 such as MatMul. Default: None.
 :return:
 """
- if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+ if self.is_910A:
+ if in_strategy is None:
+ # default: dp=1, mp=1, construct inputs only contain query, key, value
+ in_strategy = (
+ (1, 1, 1, 1),
+ (1, 1, 1, 1),
+ (1, 1, 1, 1),
+ )
+ self.flash_attention.shard(in_strategy)
+ dp = in_strategy[0][0]
+ mp = in_strategy[0][1]
+ self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
+ inputs_tensor_map = [
+ [3, 2, 1, 0],
+ [3, 2, 1, 0],
+ [3, 2, 1, 0],
+ ]
+ if self.have_attention_mask_batch:
+ inputs_tensor_map.append([3, 1, 0])
+ else:
+ inputs_tensor_map.append([-1, 1, 0])

-
-
-
-
-
+ input_empty_args_num = 2
+ # dropout_mask
+ if self.dropout_rate > 1e-5:
+ input_empty_args_num -= 1
+ inputs_tensor_map.append([3, 2, 1, 0])

-
-
-
+ if self.alibi:
+ input_empty_args_num -= 1
+ inputs_tensor_map.append([3, 2, 1, 0])

-
+ self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)

-
-
-
-
-
-
-
+ self.flash_attention.add_prim_attr("outputs_tensor_map", [
+ [3, 2, 1, 0],  # O
+ [3, 2, 1],  # L
+ [3, 2, 1]  # M
+ ])
+ self.flash_attention.add_prim_attr("as_loss_divisor", 0)
+ self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)
+ else:
+ self.flash_attention.shard(in_strategy)

 def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
 """FlashAttention forward
@@ -212,24 +228,22 @@ class FlashAttention(Cell):
 :param alibi_mask: [bsz, head_num, 1, seq_len], if not None
 :return: output [bsz, head_num, seq_len, head_dim]
 """
- query = self.scale_mul(query, self.scale_factor)
 bsz, head_num, seq_len, head_dim = query.shape
- _, k_head_num, k_seq_len, _ = key.shape
- _, v_head_num, v_seq_len, _ = value.shape
- if head_num != k_head_num or head_num != v_head_num:
- raise ValueError(
- "the head_num of query, key and value must be the same, "
- "If different head_num are used, users need to change themselves to be same by tile.")
- if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
- raise ValueError(
- "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
-
- if head_dim > 304:
- raise ValueError(
- "the head_dim must be less than 304, otherwise the ub would be OOM.")
-
 if self.is_910A:
+ _, k_head_num, k_seq_len, _ = key.shape
+ _, v_head_num, v_seq_len, _ = value.shape
+ if head_num != k_head_num or head_num != v_head_num:
+ raise ValueError(
+ "the head_num of query, key and value must be the same, "
+ "If different head_num are used, users need to change themselves to be same by tile.")
+ if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
+ raise ValueError(
+ "query, key, value seq_len must be a multiple of 16, "
+ "and the seq_len between key and value must be equal.")
 # 910A -- FlashAttentionPrimtive
+ if head_dim > 304:
+ raise ValueError(
+ "the head_dim must be less than 304, otherwise the ub would be OOM.")
 if self.dropout_rate > 1e-5:
 drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
 tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
@@ -238,27 +252,25 @@ class FlashAttention(Cell):
 drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
 else:
 drop_mask = None
+ query = self.scale_mul(query, self.scale_factor)
+ key = self.scale_mul(key, self.scale_factor)
+ attn_mask = self.cast(attn_mask, mstype.float16)
 output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
 else:
- # FlashAttentionScore
- # Useless input, just for binary calls.
+ # 910B -- FlashAttentionScore
 if self.dropout_rate > 1e-5:
 drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
 (bsz, head_num, seq_len, seq_len // 8))
 else:
 drop_mask_bits = None
- # (B,
-
- key = self.reshape(self.transpose_4d_pre(key, (0, 2, 1, 3)), (bsz, seq_len, -1))
- value = self.reshape(self.transpose_4d_pre(value, (0, 2, 1, 3)), (bsz, seq_len, -1))
- attn_mask = self.attn_expand_dims(attn_mask, 1)
+ # (B, S, S) -> (B, 1, S, S)
+ attn_mask = self.cast(self.reshape(attn_mask, (bsz, 1, seq_len, seq_len)), mstype.uint8)
 output, _, _ = self.flash_attention(query,
 key,
 value,
 attn_mask,
 drop_mask_bits,
 None,
+ None,
 None)
- output = self.transpose_4d_post(self.reshape(output, (bsz, seq_len, head_num, head_dim)), (0, 2, 1, 3))
-
 return output
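One detail worth noting in the diff above: on 910A the cell now multiplies both query and key by self.scale_factor = 1 / sqrt(sqrt(head_dim)), while on 910B it hands scale_value = 1 / sqrt(head_dim) to FlashAttentionScore. The two conventions give identical attention scores, since the per-tensor factors multiply together inside QK^T; splitting the factor across query and key presumably keeps intermediate fp16 values smaller. A quick NumPy check (illustrative only):

import math
import numpy as np

head_dim = 16
scaling_constant = math.sqrt(head_dim)
q = np.random.rand(2, 8, head_dim)
k = np.random.rand(2, 8, head_dim)

# single scale on the scores, as passed to FlashAttentionScore (scale_value = 1 / sqrt(head_dim))
scores_single = (q @ k.transpose(0, 2, 1)) / scaling_constant

# split scale on query and key, as done on 910A (each scaled by 1 / sqrt(sqrt(head_dim)))
s = 1.0 / math.sqrt(scaling_constant)
scores_split = (q * s) @ (k * s).transpose(0, 2, 1)

print(np.allclose(scores_single, scores_split))   # True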
mindspore/nn/layer/math.py
CHANGED
@@ -375,9 +375,6 @@ class DiGamma(Cell):
 nan, real_result)


- eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
-
-
 def _while_helper_func(cond, body, vals):
 while cond(vals).any():
 vals = body(vals)
@@ -394,7 +391,7 @@ def _igamma_series(ax, x, a, enabled):
 select = P.Select()

 # If more data types are supported, this epsilon need to be selected.
- epsilon =
+ epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)

 def cond(vals):
 enabled = vals[0]
@@ -443,7 +440,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
 select = P.Select()

 # If more data types are supported, this epsilon need to be selected.
- epsilon =
+ epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)

 def cond(vals):
 enabled = vals[0]
@@ -620,8 +617,7 @@ class IGamma(Cell):
 x = F.broadcast_to(x, para_shape)
 a = F.broadcast_to(a, para_shape)
 x_is_zero = self.equal(x, 0)
-
- underflow = self.less(ax, self.neg(log_maxfloat))
+ underflow = self.less(ax, self.neg(self.log_maxfloat32))
 ax = self.exp(ax)
 enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
 output = self.select(use_igammac,
mindspore/nn/layer/rnn_cells.py
CHANGED
@@ -83,7 +83,7 @@ def _check_lstmcell_init(func):


 def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-
+ """RNN cell function with tanh activation"""
 if b_ih is None:
 igates = P.MatMul(False, True)(inputs, w_ih)
 hgates = P.MatMul(False, True)(hidden, w_hh)
@@ -94,7 +94,7 @@ def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


 def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-
+ """RNN cell function with relu activation"""
 if b_ih is None:
 igates = P.MatMul(False, True)(inputs, w_ih)
 hgates = P.MatMul(False, True)(hidden, w_hh)
@@ -105,7 +105,7 @@ def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


 def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-
+ """LSTM cell function"""
 hx, cx = hidden
 if b_ih is None:
 gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh)
@@ -125,7 +125,7 @@ def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


 def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-
+ """GRU cell function"""
 if b_ih is None:
 gi = P.MatMul(False, True)(inputs, w_ih)
 gh = P.MatMul(False, True)(hidden, w_hh)
@@ -144,7 +144,7 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


 class RNNCellBase(Cell):
-
+ """Basic class for RNN Cells"""
 def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int,
 dtype=mstype.float32):
 super().__init__()
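The docstrings added above only name what these helpers compute; for reference, here is a NumPy sketch of the arithmetic behind _lstm_cell (standard LSTM cell equations; the i, f, g, o gate ordering is an assumption for illustration, not read from the MindSpore source). P.MatMul(False, True) corresponds to multiplying by the transposed weight, i.e. x @ w_ih.T:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_cell_step(x, hx, cx, w_ih, w_hh, b_ih=None, b_hh=None):
    # w_ih: (4*hidden, input), w_hh: (4*hidden, hidden) -- stacked gate weights
    gates = x @ w_ih.T + hx @ w_hh.T
    if b_ih is not None:
        gates = gates + b_ih + b_hh
    i, f, g, o = np.split(gates, 4, axis=-1)            # input, forget, cell, output gates (assumed order)
    c_next = sigmoid(f) * cx + sigmoid(i) * np.tanh(g)
    h_next = sigmoid(o) * np.tanh(c_next)
    return h_next, c_next

x, hx, cx = np.random.rand(3, 5), np.random.rand(3, 8), np.random.rand(3, 8)
w_ih, w_hh = np.random.rand(32, 5), np.random.rand(32, 8)
h, c = lstm_cell_step(x, hx, cx, w_ih, w_hh)
print(h.shape, c.shape)   # (3, 8) (3, 8)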
mindspore/nn/wrap/cell_wrapper.py
CHANGED
@@ -644,6 +644,9 @@ class PipelineCell(Cell):
 self.micro_inputs = nn.CellList()
 self.micro_size = micro_size
 self.add_list = []
+ if not isinstance(network, Cell):
+ raise TypeError("For 'PipelineCell', the argument 'network' must cell type, "
+ "but got the type : {}.".format(type(network)))
 if not isinstance(micro_size, int):
 raise TypeError("For 'PipelineCell', the argument 'micro_size' must be integer, "
 "but got the type : {}.".format(type(micro_size)))
@@ -689,6 +692,9 @@ class GradAccumulationCell(Cell):
 self.micro_inputs = nn.CellList()
 self.micro_size = micro_size
 self.add_list = []
+ if not isinstance(network, Cell):
+ raise TypeError("For 'GradAccumulationCell', the argument 'network' must cell type, "
+ "but got the type : {}.".format(type(network)))
 if not isinstance(micro_size, int):
 raise TypeError("For 'GradAccumulationCell', the argument 'micro_size' must be integer, "
 "but got the type : {}.".format(type(micro_size)))
mindspore/numpy/utils_const.py
CHANGED
@@ -143,8 +143,8 @@ def _infer_out_shape(*shapes):
 shape_out = list()
 max_len = max([len(it) for it in shapes])
 for i in range(max_len):
- items = [
-
+ items = [
+ it[i - max_len + len(it)] if i - max_len + len(it) >= 0 else 1 for it in shapes]
 max_size = 0 if 0 in items else max(items)
 _check()
 shape_out.append(max_size)
@@ -158,8 +158,8 @@ def _can_broadcast(*shapes):
 """
 max_len = max([len(it) for it in shapes])
 for i in range(max_len):
- items = [
-
+ items = [
+ it[i - max_len + len(it)] if i - max_len + len(it) >= 0 else 1 for it in shapes]
 max_size = 0 if 0 in items else max(items)
 if any(item not in (1, max_size) for item in items):
 return False
@@ -399,7 +399,7 @@ def _broadcast_tuples(tup1, tup2):
 if not isinstance(tup1, (tuple, list)) or not isinstance(tup2, (tuple, list)):
 raise TypeError("input shift and axis must be tuple or list or int.")
 if len(tup1) == len(tup2) or len(tup1) == 1 or len(tup2) == 1:
- return
+ return
 raise ValueError("shape mismatch: objects cannot be broadcast to a single shape")

 tup1 = (tup1,) if isinstance(tup1, int) else tup1
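The items = [it[i - max_len + len(it)] ...] expression that gets re-wrapped above is the usual right-aligned broadcasting rule: pad shorter shapes with 1 on the left, then take the per-axis maximum (with 0 propagating). A standalone sketch that mirrors that logic and checks it against NumPy (illustrative only, without the error checking the MindSpore helpers do separately):

import numpy as np

def infer_out_shape(*shapes):
    # right-align the shapes, pad with 1, and take the max along each axis (0 wins)
    max_len = max(len(s) for s in shapes)
    out = []
    for i in range(max_len):
        items = [s[i - max_len + len(s)] if i - max_len + len(s) >= 0 else 1 for s in shapes]
        out.append(0 if 0 in items else max(items))
    return tuple(out)

print(infer_out_shape((5, 1, 3), (4, 3)))        # (5, 4, 3)
print(np.broadcast_shapes((5, 1, 3), (4, 3)))    # (5, 4, 3) -- NumPy agrees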
mindspore/ops/_grad_experimental/grad_array_ops.py
CHANGED
@@ -203,7 +203,7 @@ def get_bprop_index_put(self):
 if is_ascend:
 indices_ms = [convert_idx_positive(indices_ms[i], x1.shape[i]) for i in range(len(indices_ms))]
 indices_me = stack(indices_ms)
- indices_grad = F.transpose(indices_me, F.make_range(F.rank(indices_me)-1, -1, -1))
+ indices_grad = F.transpose(indices_me, F.make_range(F.rank(indices_me) - 1, -1, -1))
 values_grad = gather_nd(dout, indices_grad)
 if equal(cast(x2.shape[0], mstype.int32), Tensor(1)):
 values_grad = values_grad.sum().reshape(1)
mindspore/ops/_grad_experimental/grad_implementations.py
CHANGED
@@ -19,7 +19,7 @@ from mindspore.ops import functional as F
 from mindspore.ops import operations as P
 from mindspore.ops.composite import multitype_ops as C
 from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
- from mindspore.ops._grad_experimental.grad_base import bprops
+ from mindspore.ops._grad_experimental.grad_base import bprops, bprop_getters
 from mindspore.common import dtype as mstype

 get_dtype = P.DType()
@@ -193,7 +193,7 @@ def bprop_tensor_move(x, out, dout):
 return (dout,)


- @
+ @bprop_getters.register("DictInplaceSetItem")
 def get_bprop_dict_inplace_setitem(self):
 """Generate bprop for dict inplace pop"""

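The change above registers get_bprop_dict_inplace_setitem through the bprop_getters.register(...) decorator imported from grad_base. A minimal sketch of that decorator-registry pattern (a simplified stand-in, not MindSpore's actual Registry class, and with a placeholder gradient rule):

class Registry:
    """Map primitive names to bprop-getter functions via a decorator."""
    def __init__(self):
        self._table = {}

    def register(self, name):
        def deco(fn):
            self._table[name] = fn   # remember the getter under the primitive's name
            return fn
        return deco

    def get(self, name):
        return self._table[name]

bprop_getters = Registry()

@bprop_getters.register("DictInplaceSetItem")
def get_bprop_dict_inplace_setitem(prim):
    def bprop(*args):
        return args[:-2]             # placeholder gradient rule for the sketch
    return bprop

print(bprop_getters.get("DictInplaceSetItem").__name__)   # get_bprop_dict_inplace_setitem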
mindspore/ops/_grad_experimental/grad_math_ops.py
CHANGED
@@ -135,7 +135,7 @@ def get_bprop_matrix_triangular_solve(self):

 def bprop(matrix, rhs, out, dout):
 grad_rhs = matrix_triangular_solve_op(matrix, dout)
- if matrix.dtype
+ if matrix.dtype in (mstype.complex64, mstype.complex128):
 grad_rhs_temp = _adjoint(grad_rhs)
 out_temp = _adjoint(out)
 else:
@@ -156,14 +156,14 @@ def get_bprop_matrix_triangular_solve(self):
 grad_matrix = mat_mul_op(grad_rhs, out_temp)
 grad_matrix = neg_op(grad_matrix)
 if lower_a:
- if grad_matrix.dtype
+ if grad_matrix.dtype in (mstype.complex64, mstype.complex128):
 grad_matrix_real = matrix_band_part_op(real_op(grad_matrix), -1, 0)
 grad_matrix_imag = matrix_band_part_op(imag_op(grad_matrix), -1, 0)
 grad_matrix = complex_op(grad_matrix_real, grad_matrix_imag)
 else:
 grad_matrix = matrix_band_part_op(grad_matrix, -1, 0)
 else:
- if grad_matrix.dtype
+ if grad_matrix.dtype in (mstype.complex64, mstype.complex128):
 grad_matrix_real = matrix_band_part_op(real_op(grad_matrix), 0, -1)
 grad_matrix_imag = matrix_band_part_op(imag_op(grad_matrix), 0, -1)
 grad_matrix = complex_op(grad_matrix_real, grad_matrix_imag)
@@ -219,7 +219,7 @@ def get_bprop_matrix_solve(self):
 @_primexpr
 def _generate_perm_matrix_solve_ls(x_dim):
 perm = tuple(range(x_dim - 2))
- perm = perm + (x_dim-1, x_dim-2)
+ perm = perm + (x_dim - 1, x_dim - 2)
 return perm


@@ -647,20 +647,21 @@ def _fft_rank_offset(norm_shape, rank):
 @_primexpr
 def _fft_with_size_back_norm(norm_shape, norm, inverse, rank):
 """generate reverse term for fft_with_size"""
+ norm_ = None
 if inverse is False:
 if norm == "forward":
- norm_ = 1 / _fft_rank_offset(norm_shape, rank)
-
- norm_ = 1 * _fft_rank_offset(norm_shape, rank)
-
- norm_ = 1
-
+ norm_ = 1.0 / _fft_rank_offset(norm_shape, rank)
+ elif norm == "backward":
+ norm_ = 1.0 * _fft_rank_offset(norm_shape, rank)
+ elif norm == "ortho":
+ norm_ = 1.0
+ else:
 if norm == "forward":
- norm_ = 1 * _fft_rank_offset(norm_shape, rank)
-
- norm_ = 1 / _fft_rank_offset(norm_shape, rank)
-
- norm_ = 1
+ norm_ = 1.0 * _fft_rank_offset(norm_shape, rank)
+ elif norm == "backward":
+ norm_ = 1.0 / _fft_rank_offset(norm_shape, rank)
+ elif norm == "ortho":
+ norm_ = 1.0
 return norm_


@@ -670,9 +671,9 @@ def _rfft_norm(norm_shape, norm, rank):
 norm_ = 1.0
 if norm == "forward":
 norm_ = 1 / _fft_rank_offset(norm_shape, rank)
-
- norm_ = 1
-
+ elif norm == "backward":
+ norm_ = 1.0
+ elif norm == "ortho":
 norm_ = 1 / np.sqrt(_fft_rank_offset(norm_shape, rank))
 return norm_

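The corrected elif ladders above encode the usual FFT normalization bookkeeping for the fft_with_size backward pass: with norm="forward" the transform divides by the element count, with norm="backward" it does not, and "ortho" splits the factor symmetrically so nothing extra appears in the gradient. A NumPy check of those relationships (illustrative, using numpy.fft rather than MindSpore's fft_with_size path):

import numpy as np

x = np.random.rand(8)
n = x.size

fwd = np.fft.fft(x, norm="forward")          # divides by n
bwd = np.fft.fft(x, norm="backward")         # unnormalized
print(np.allclose(bwd, fwd * n))             # True: the two conventions differ by exactly n

ortho = np.fft.fft(x, norm="ortho")          # divides by sqrt(n)
print(np.allclose(np.fft.ifft(ortho, norm="ortho"), x))   # True: ortho scaling round-trips cleanly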
mindspore/ops/_grad_experimental/grad_sparse_ops.py
CHANGED
@@ -358,10 +358,10 @@ def get_bprop_ragged_tensor_to_sparse(self):
 split.append(zeros_like(i))
 all_d = (split, ragged_values_grad)
 return all_d
-
+ split_ = ()
 for i in enumerate(rt_nested_splits):
-
- all_d = (
+ split_ = split_ + (zeros_like(i),)
+ all_d = (split_, ragged_values_grad)
 return all_d

 return bprop