@fugood/llama.node 1.0.2 → 1.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +14 -14
- package/src/llama.cpp/CMakeLists.txt +0 -1
- package/src/llama.cpp/common/arg.cpp +7 -0
- package/src/llama.cpp/common/common.h +1 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +7 -2
- package/src/llama.cpp/ggml/include/ggml.h +91 -10
- package/src/llama.cpp/ggml/src/CMakeLists.txt +0 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +726 -155
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +9 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +49 -9
- package/src/llama.cpp/include/llama.h +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +90 -2
- package/src/llama.cpp/src/llama-arch.h +6 -0
- package/src/llama.cpp/src/llama-batch.cpp +27 -1
- package/src/llama.cpp/src/llama-batch.h +8 -1
- package/src/llama.cpp/src/llama-chat.cpp +15 -0
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +64 -50
- package/src/llama.cpp/src/llama-graph.h +41 -16
- package/src/llama.cpp/src/llama-hparams.cpp +2 -1
- package/src/llama.cpp/src/llama-hparams.h +1 -0
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
- package/src/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
- package/src/llama.cpp/src/llama-kv-cache-unified.h +62 -24
- package/src/llama.cpp/src/llama-kv-cells.h +62 -10
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
- package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +15 -2
- package/src/llama.cpp/src/llama-memory.cpp +17 -0
- package/src/llama.cpp/src/llama-memory.h +3 -0
- package/src/llama.cpp/src/llama-model.cpp +1234 -248
- package/src/llama.cpp/src/llama-model.h +2 -0
- package/src/llama.cpp/src/llama-vocab.cpp +8 -1
- package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50
The diff body below covers package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp (+726 -155); the other files' hunks are not reproduced in this section.

@@ -3,6 +3,7 @@
 #include "ggml-cpu.h"
 #include "ggml-impl.h"
 #include "binary-ops.h"
+#include "ggml.h"
 #include "unary-ops.h"
 #include "vec.h"
 
@@ -3613,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_geglu_erf
+
+static void ggml_compute_forward_geglu_erf_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_erf_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_erf_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_geglu_quick
+
+static void ggml_compute_forward_geglu_quick_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_quick_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_quick_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
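The four row kernels above hand the element-wise math to `ggml_vec_geglu_erf_{f32,f16}` and `ggml_vec_geglu_quick_{f32,f16}`, which this release adds to `ggml/src/ggml-cpu/vec.h`/`vec.cpp` (see the file list; their bodies are not reproduced in this section). As a reference point only, a scalar sketch of what a GEGLU gate conventionally computes for the exact (erf) and "quick" (sigmoid-approximation) GELU variants, assuming the standard definitions:

```c
#include <math.h>

// Reference-only sketch, not the vec.h implementation:
// per element, dst[i] = GELU(x[i]) * g[i].
static void geglu_erf_ref_f32(int n, float * dst, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        // exact GELU: 0.5*x*(1 + erf(x/sqrt(2)))
        dst[i] = 0.5f*x[i]*(1.0f + erff(x[i]*0.70710678f)) * g[i];
    }
}

static void geglu_quick_ref_f32(int n, float * dst, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        // "quick" GELU: x*sigmoid(1.702*x)
        dst[i] = x[i]*(1.0f/(1.0f + expf(-1.702f*x[i]))) * g[i];
    }
}
```

Note that the `swapped` op-param above only decides which half of a fused row acts as the value and which as the gate when `src1` is absent.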
@@ -5231,14 +5518,17 @@ static void ggml_compute_forward_soft_max_f32(
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    // TODO: handle transposed/permuted matrices
-
     const int ith = params->ith;
     const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     //   https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5248,68 +5538,66 @@ static void ggml_compute_forward_soft_max_f32(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
-    for ( …
-    … (removed lines not shown in the source diff)
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const int64_t i11 = i01;
+                const int64_t i12 = i02%ne12;
+                const int64_t i13 = i03%ne13;
+
+                // ALiBi
+                const uint32_t h = i02; // head
+                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);
+
+                // broadcast the mask across rows
+                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+                ggml_vec_cpy_f32  (ne00, wp, sp);
+                ggml_vec_scale_f32(ne00, wp, scale);
+                if (mp_f32) {
+                    if (use_f16) {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+                        }
+                    } else {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*mp_f32[i];
+                        }
+                    }
                 }
-        }
-    }
 
 #ifndef NDEBUG
-    …
+                for (int i = 0; i < ne00; ++i) {
+                    //printf("p[%d] = %f\n", i, p[i]);
+                    assert(!isnan(wp[i]));
+                }
 #endif
 
-    …
+                float max = -INFINITY;
+                ggml_vec_max_f32(ne00, &max, wp);
 
-    …
+                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+                assert(sum > 0.0);
 
-    …
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(ne00, dp, sum);
 
 #ifndef NDEBUG
-    …
+                for (int i = 0; i < ne00; ++i) {
+                    assert(!isnan(dp[i]));
+                    assert(!isinf(dp[i]));
+                }
 #endif
+            }
+        }
     }
 }
 
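Written out, the rewritten loop computes a scaled, ALiBi-biased softmax per row, with the mask `src1` broadcast over dims 2 and 3 through `i02%ne12` and `i03%ne13`:

$$
w_i = \text{scale}\cdot x_i + \text{slope}\cdot m_i,
\qquad
d_i = \frac{e^{\,w_i - \max_k w_k}}{\sum_j e^{\,w_j - \max_k w_k}}
$$

where, for head $h = i02$ and $L$ the `n_head_log2` above, $\text{slope} = m_0^{\,h+1}$ when $h < L$, $m_1^{\,2(h-L)+1}$ otherwise, and 1 when `max_bias` is zero. The threading shape also changes: instead of splitting a flat row range per thread, rows are strided by `nth` inside the `i01` loop.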
@@ -6545,6 +6833,186 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }
 
+static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                              void * a, void * b, float * c) {
+    const ggml_type_traits * traits = ggml_get_type_traits(type);
+    struct ggml_tensor src1 = {};
+    src1.type  = type;
+    src1.ne[0] = k;
+    src1.ne[1] = m;
+    src1.ne[2] = 1;
+    src1.ne[3] = 1;
+    src1.nb[0] = traits->type_size;
+    src1.nb[1] = k * traits->type_size;
+    src1.nb[2] = src1.nb[1];
+    src1.nb[3] = src1.nb[2];
+    src1.data  = a;
+
+    struct ggml_tensor src0 = {};
+    src0.type  = type;
+    src0.ne[0] = k;
+    src0.ne[1] = n;
+    src0.ne[2] = 1;
+    src0.ne[3] = 1;
+    src0.nb[0] = traits->type_size;
+    src0.nb[1] = k * traits->type_size;
+    src0.nb[2] = src0.nb[1];
+    src0.nb[3] = src0.nb[2];
+    src0.data  = b;
+
+    struct ggml_tensor dst = {};
+    dst.ne[0] = n;
+    dst.ne[1] = m;
+    dst.ne[2] = 1;
+    dst.ne[3] = 1;
+    dst.nb[0] = sizeof(float);
+    dst.nb[1] = n * sizeof(float);
+    dst.nb[2] = dst.nb[1];
+    dst.nb[3] = dst.nb[2];
+    dst.data  = c;
+    dst.src[0] = &src0;
+    dst.src[1] = &src1;
+
+    ggml_compute_forward_mul_mat(params, &dst);
+}
+
+// ggml_compute_forward_conv_2d
+
+static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
+                                              const ggml_tensor *         src,     // [W, H, C, N]
+                                              ggml_tensor *               dst,     // [OW, OH, OC, N]
+                                              ggml_type                   kernel_type) {
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+    const int32_t stride_x   = dst->op_params[0];
+    const int32_t stride_y   = dst->op_params[1];
+    const int32_t pad_x      = dst->op_params[2];
+    const int32_t pad_y      = dst->op_params[3];
+    const int32_t dilation_x = dst->op_params[4];
+    const int32_t dilation_y = dst->op_params[5];
+
+    const int64_t c_in  = src->ne[2];
+    const int64_t c_out = kernel->ne[3];
+    GGML_ASSERT(c_in == kernel->ne[2]);
+
+    const int64_t src_w = src->ne[0];
+    const int64_t src_h = src->ne[1];
+    const int64_t knl_w = kernel->ne[0];
+    const int64_t knl_h = kernel->ne[1];
+    const int64_t dst_w = dst->ne[0];
+    const int64_t dst_h = dst->ne[1];
+
+    const float * src_data = (float *) src->data;
+    void        * knl_data = kernel->data;
+    float       * dst_data = (float *) dst->data;
+
+    const int64_t knl_n       = knl_w * knl_h * c_in;
+    const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
+
+    const int64_t space_per_patch   = knl_n * traits->type_size + c_out * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
+    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+    void * tmp = params->wdata;
+
+    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+
+        const int64_t patch_start_batch = batch_i * patches_per_batch;
+        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,
+                                                   patch_total);
+        const int64_t patch_n           = patch_end_batch - patch_start_batch;
+
+        const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
+        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+        //im2col for a patch
+        for (int64_t p = patch_start; p < patch_end; ++p) {
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t src_x   = (p / dst_w) % dst_h;
+            const int64_t src_y   = p % dst_w;
+
+            const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
+            char *        dst_row  = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
+
+            for (int64_t ic = 0; ic < c_in; ++ic) {
+                for (int64_t ky = 0; ky < knl_h; ++ky) {
+                    for (int64_t kx = 0; kx < knl_w; ++kx) {
+                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
+                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
+
+                        int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
+
+                        float src_val;
+                        if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+                            src_val = 0.0f;
+                        } else {
+                            const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
+                            src_val = *src_ptr;
+                        }
+
+                        char * element_ptr = dst_row + dst_idx * traits->type_size;
+                        if (kernel_type == GGML_TYPE_F32) {
+                            *(float *) element_ptr = src_val;
+                        } else if (kernel_type == GGML_TYPE_F16) {
+                            *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+                        }
+                    }
+                }
+            }
+        } // patches handled by this thread
+
+        ggml_barrier(params->threadpool);
+
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
+
+        GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);
+
+        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
+        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
+
+        ggml_barrier(params->threadpool);
+
+
+        //permute back [OC, N, OH, OW] to [N, OC, OH, OW]
+        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
+        const int64_t permute_start = params->ith * permute_per_thread;
+        const int64_t permute_end   = std::min(permute_start + permute_per_thread, patch_n);
+
+        for (int64_t i = permute_start; i < permute_end; ++i) {
+            const int64_t p       = patch_start_batch + i;
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t dst_y   = (p / dst_w) % dst_h;
+            const int64_t dst_x   = p % dst_w;
+
+            for (int64_t oc = 0; oc < c_out; ++oc) {
+                const float value = gemm_output[i * c_out + oc];
+                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
+                *dst_ptr = value;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_2d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(
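The new CPU `conv_2d` is im2col followed by a GEMM: each output pixel's receptive field is flattened into a `knl_n`-wide row of the scratch buffer, multiplied against the flattened kernel through `ggml_call_mul_mat`, then scattered back into `dst`'s layout. Batching patches through the fixed-size workspace (`patches_per_batch`, rounded down to a multiple of 8) keeps the scratch requirement at `wsize` regardless of image size. The hunk assumes `dst` already carries the usual convolution output extents; a sketch of that arithmetic (standard dilated-convolution bookkeeping, not code from this diff):

```c
#include <stdint.h>

// Hypothetical helper: the relation dst_w/dst_h are expected to satisfy above.
static int64_t conv2d_out_size(int64_t in, int64_t knl,
                               int64_t stride, int64_t pad, int64_t dilation) {
    return (in + 2*pad - dilation*(knl - 1) - 1) / stride + 1;
}

// e.g. dst_w == conv2d_out_size(src_w, knl_w, stride_x, pad_x, dilation_x)
//      dst_h == conv2d_out_size(src_h, knl_h, stride_y, pad_y, dilation_y)
```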
@@ -7095,12 +7563,13 @@ static void ggml_compute_forward_upscale_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    … (removed lines not shown in the source diff)
+    float sf0 = (float)ne0/src0->ne[0];
+    float sf1 = (float)ne1/src0->ne[1];
+    float sf2 = (float)ne2/src0->ne[2];
+    float sf3 = (float)ne3/src0->ne[3];
 
-    const …
+    const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
+    const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
 
     if (mode == GGML_SCALE_MODE_NEAREST) {
         for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -7121,8 +7590,12 @@ static void ggml_compute_forward_upscale_f32(
             }
         }
     } else if (mode == GGML_SCALE_MODE_BILINEAR) {
-        … (removed lines not shown in the source diff)
+        float pixel_offset = 0.5f;
+        if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+            pixel_offset = 0.0f;
+            sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
+            sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
+        }
 
         for (int64_t i3 = 0; i3 < ne3; i3++) {
             const int64_t i03 = i3 / sf3;
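The scale mode now packs flags into the high bits of op_param 0, with `GGML_SCALE_FLAG_ALIGN_CORNERS` selecting the align-corners bilinear convention. The sampling loop that consumes `pixel_offset` sits outside the shown hunks, but the two conventions these scale factors imply are the standard ones. For one axis with input size $n$ and output size $m$:

$$
\text{half-pixel (default):}\quad s = \frac{m}{n},\qquad x_{\text{src}} = \frac{x_{\text{dst}} + 0.5}{s} - 0.5
$$

$$
\text{align-corners:}\quad s = \frac{m-1}{n-1},\qquad x_{\text{src}} = \frac{x_{\text{dst}}}{s}
$$

so with align-corners the first and last samples of each row map exactly onto the input corners.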
@@ -7580,7 +8053,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_type …
+    ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
     ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
     ggml_vec_dot_t    const kq_vec_dot   = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t   const v_to_float   = ggml_get_type_traits(v->type)->to_float;
@@ -7612,7 +8085,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
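The one-line mask change generalizes flash attention from an implicitly 2-D mask to one broadcast over the head and batch dims, mirroring what `soft_max` now does:

$$
mp \;\leftarrow\; \text{mask}\big[\,:,\; iq_1,\; iq_2 \bmod ne_2^{\text{mask}},\; iq_3 \bmod ne_3^{\text{mask}}\,\big]
$$

A mask with singleton head/batch dims (e.g. shape {n_kv, n_tokens, 1, 1}) behaves exactly as before, since the modulo indices collapse to 0.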
@@ -8150,120 +8623,210 @@ void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // s
-    const ggml_tensor * src1 = dst->src[1]; // x
-    const ggml_tensor * src2 = dst->src[2]; // dt
-    const ggml_tensor * src3 = dst->src[3]; // A
-    const ggml_tensor * src4 = dst->src[4]; // B
-    const ggml_tensor * src5 = dst->src[5]; // C
+    const ggml_tensor * src0 = dst->src[0]; // s   {d_state, dim, n_head, n_seqs+}
+    const ggml_tensor * src1 = dst->src[1]; // x   {dim, n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src2 = dst->src[2]; // dt  {n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src3 = dst->src[3]; // A   {d_state, n_head} or {1, n_head}
+    const ggml_tensor * src4 = dst->src[4]; // B   {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src5 = dst->src[5]; // C   {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t nc …
-    const int64_t nr …
-    const int64_t …
-    const int64_t …
+    const int64_t nc = src0->ne[0]; // d_state
+    const int64_t nr = src0->ne[1]; // dim
+    const int64_t nh = src1->ne[1]; // n_head
+    const int64_t ng = src4->ne[1];
+    const int64_t nt = src1->ne[2]; // number of tokens per sequence
+    const int64_t ns = src1->ne[3]; // number of sequences in the batch
+
+    // can't use ggml_nbytes because src1 is not necessarily contiguous
+    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
 
-    GGML_ASSERT(ggml_nelements(src1) + …
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    … (removed lines not shown in the source diff)
-    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+    // allows optimizing the modulo since n_group should be a power of 2
+    GGML_ASSERT((ng & -ng) == ng);
 
-    // …
-    const int …
+    // heads per thread
+    const int dh = (nh + nth - 1)/nth;
 
-    // …
-    const int …
-    const int …
-
+    // head range for this thread
+    const int ih0 = dh*ith;
+    const int ih1 = MIN(ih0 + dh, nh);
+
+    const int32_t * ids = (const int32_t *) src6->data;
+
+    for (int i3 = 0; i3 < ns; ++i3) {
+        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+        float *       s  = (      float *) ((      char *)  dst->data +     i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+        for (int i2 = 0; i2 < nt; ++i2) {
+            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+            float *       y  = (      float *) ((      char *)  dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+            if (src3->ne[0] == 1) {
+                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                    const float dA = expf(dt_soft_plus * A[h]);
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+                        float sumf = 0.0f;
+#if defined(GGML_SIMD)
+#if defined(__ARM_FEATURE_SVE)
+                        const int ggml_f32_epr = svcntw();
+                        const int ggml_f32_step = 1 * ggml_f32_epr;
+
+                        const int np = (nc & ~(ggml_f32_step - 1));
+
+                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        for (int i = 0; i < np; i += ggml_f32_step) {
+                            // TODO: maybe unroll more?
+                            for (int j = 0; j < 1; j++) {
+                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B  + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C  + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+
+                                t0 = GGML_F32_VEC_MUL(t0, adA);
+                                t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+                                t0 = GGML_F32_VEC_ADD(t0, t1);
+
+                                sum = GGML_F32_VEC_FMA(sum, t0, t2);
+
+                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+                            }
+                        }
+
+                        sumf = GGML_F32xt_REDUCE_ONE(sum);
+#else
+                        const int np = (nc & ~(GGML_F32_STEP - 1));
+
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        GGML_F32_VEC az[GGML_F32_ARR];
+
+                        for (int i = 0; i < np; i += GGML_F32_STEP) {
+                            for (int j = 0; j < GGML_F32_ARR; j++) {
+                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+                                ay[j] = GGML_F32_VEC_LOAD(B  + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+                                az[j] = GGML_F32_VEC_LOAD(C  + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
 
-            … (removed lines not shown in the source diff)
-            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
-            t1 = exp_ps_sve(svptrue_b32(), t1);
-            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
-            vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
-            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
-            GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
+
+                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+                            }
+                        }
+
+                        // reduce sum0..sum3 to sum0
+                        GGML_F32_VEC_REDUCE(sumf, sum);
+#endif
+#else
+                        const int np = 0;
+#endif
+                        // d_state
+                        for (int i0 = np; i0 < nc; ++i0) {
+                            const int i = i0 + ii*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
                     }
-            y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
                 }
-        }
-
-            … (removed lines not shown in the source diff)
+            } else {
+                // Mamba-1 has an element-wise decay factor for the states
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+#if defined(__ARM_FEATURE_SVE)
+                        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                        svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+                        // d_state
+                        // TODO: what happens when (d_state % svcntw()) != 0?
+                        for (int64_t k = 0; k < nc; k += svcntw()) {
+                            svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
+                            svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
+                            svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+                            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+                            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                            t1 = exp_ps_sve(svptrue_b32(), t1);
+                            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                            vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
+                            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                            GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+                        }
+                        y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
+#else
+                        float sumf = 0.0f;
+                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+                        //       and also because expf is used within the loop.
+                        // d_state
+                        for (int i0 = 0; i0 < nc; ++i0) {
+                            const int i = i0 + ii*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
+#endif
                     }
-            y[i1] = sumf;
                 }
             }
+            // use the output as the source when it's not the first token-wise iteration
+            s0 = s;
         }
-
+    }
 }
 
 void ggml_compute_forward_ssm_scan(
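Both branches implement the same selective-state-space step per head $h$ and token $t$, with $\Delta$ the numerically guarded softplus of `dt`:

$$
\Delta = \begin{cases} \log\!\left(1+e^{dt}\right) & dt \le 20 \\ dt & \text{otherwise} \end{cases}
\qquad
s_t = e^{\Delta A} \odot s_{t-1} + (\Delta\, x_t)\, B_t,
\qquad
y_t = \sum_{i=1}^{d\_state} s_t[i]\; C_t[i]
$$

In the Mamba-2 branch, `A` is a per-head scalar, so `dA = expf(dt_soft_plus * A[h])` hoists out of the state loop and the body reduces to plain FMAs; in the Mamba-1 branch the decay is per state element, which forces `expf` inside the loop (hence the note that generic GGML_SIMD is not used there outside SVE). The new `ids` input lets each sequence select which stored state slot to read, and the trailing `s0 = s` chains the recurrence across tokens.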
@@ -8502,6 +9065,14 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            {
+                ggml_compute_forward_geglu_erf(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            {
+                ggml_compute_forward_geglu_quick(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
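The switch above implies matching `GGML_GLU_OP_GEGLU_ERF` / `GGML_GLU_OP_GEGLU_QUICK` values in the `ggml_glu_op` enum, which belongs to the `ggml/include/ggml.h` changes (+91 -10) listed but not reproduced in this section. A hypothetical graph-building sketch; the builder name is an assumption modeled on ggml's existing GLU helpers such as `ggml_geglu()`:

```c
#include "ggml.h"

// Hypothetical usage sketch, not code from this diff: apply the new erf-GELU
// gate to a fused up-projection tensor. ggml_geglu_erf() is assumed to exist
// by analogy with ggml_geglu(); it would take the src1 == NULL path of the
// CPU kernels above, splitting ffn_up in half into value and gate.
static struct ggml_tensor * ffn_geglu_erf(struct ggml_context * ctx,
                                          struct ggml_tensor  * ffn_up) {
    return ggml_geglu_erf(ctx, ffn_up);
}
```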
|