@fugood/llama.node 1.0.2 → 1.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +14 -14
- package/src/llama.cpp/CMakeLists.txt +0 -1
- package/src/llama.cpp/common/CMakeLists.txt +4 -5
- package/src/llama.cpp/common/arg.cpp +44 -0
- package/src/llama.cpp/common/common.cpp +22 -6
- package/src/llama.cpp/common/common.h +15 -1
- package/src/llama.cpp/ggml/CMakeLists.txt +10 -2
- package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/src/llama.cpp/ggml/include/ggml.h +104 -10
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +749 -163
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +12 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +88 -9
- package/src/llama.cpp/include/llama.h +13 -47
- package/src/llama.cpp/src/llama-arch.cpp +298 -3
- package/src/llama.cpp/src/llama-arch.h +22 -1
- package/src/llama.cpp/src/llama-batch.cpp +103 -71
- package/src/llama.cpp/src/llama-batch.h +31 -18
- package/src/llama.cpp/src/llama-chat.cpp +59 -1
- package/src/llama.cpp/src/llama-chat.h +3 -0
- package/src/llama.cpp/src/llama-context.cpp +134 -95
- package/src/llama.cpp/src/llama-context.h +13 -16
- package/src/llama.cpp/src/llama-cparams.h +3 -2
- package/src/llama.cpp/src/llama-graph.cpp +279 -180
- package/src/llama.cpp/src/llama-graph.h +183 -122
- package/src/llama.cpp/src/llama-hparams.cpp +47 -1
- package/src/llama.cpp/src/llama-hparams.h +12 -1
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/src/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/src/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/src/llama.cpp/src/llama-kv-cells.h +62 -10
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +21 -11
- package/src/llama.cpp/src/llama-memory.cpp +17 -0
- package/src/llama.cpp/src/llama-memory.h +3 -0
- package/src/llama.cpp/src/llama-model.cpp +3373 -743
- package/src/llama.cpp/src/llama-model.h +20 -4
- package/src/llama.cpp/src/llama-quant.cpp +2 -2
- package/src/llama.cpp/src/llama-vocab.cpp +376 -10
- package/src/llama.cpp/src/llama-vocab.h +43 -0
- package/src/llama.cpp/src/unicode.cpp +207 -0
- package/src/llama.cpp/src/unicode.h +2 -0
- package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50

package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp:

@@ -3,6 +3,7 @@
 #include "ggml-cpu.h"
 #include "ggml-impl.h"
 #include "binary-ops.h"
+#include "ggml.h"
 #include "unary-ops.h"
 #include "vec.h"
 
@@ -3613,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_geglu_erf
+
+static void ggml_compute_forward_geglu_erf_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_erf_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_erf_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_geglu_quick
+
+static void ggml_compute_forward_geglu_quick_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_quick_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_quick_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
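
Note: the per-element gating these new kernels delegate to ggml_vec_geglu_erf_f32/f16 and ggml_vec_geglu_quick_f32/f16 (see the vec.h changes in this release) reduces to the two GELU variants applied to the gate half of the row. A minimal scalar sketch, assuming the standard erf-based GELU and ggml's sigmoid(1.702x) "quick" approximation — the helper names below are illustrative, not part of the diff:

```cpp
#include <math.h>

// GEGLU-erf: out = GELU_erf(x) * g, with the exact erf-based GELU
static float geglu_erf_ref(float x, float g) {
    return 0.5f*x*(1.0f + erff(x/sqrtf(2.0f))) * g;
}

// GEGLU-quick: out = GELU_quick(x) * g, sigmoid approximation of GELU
static float geglu_quick_ref(float x, float g) {
    return (x/(1.0f + expf(-1.702f*x))) * g;
}
```

When src1 is NULL, gate and value share one fused row of width 2*nc, and the `swapped` op param decides which half is x and which is g.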
@@ -3728,6 +4015,9 @@ static void ggml_compute_forward_rms_norm_f32(
 
                 const float scale = 1.0f/sqrtf(mean + eps);
 
+                // if you hit this, likely you got an inf somewhere earlier
+                assert(scale > 0.0f);
+
                 ggml_vec_scale_f32(ne00, y, scale);
             }
         }
@@ -4356,9 +4646,11 @@ static void ggml_compute_forward_scale_f32(
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    // scale factor
-    float v;
-    memcpy(&v, dst->op_params, sizeof(float));
+    float s; // scale factor
+    float b; // bias
+
+    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -4377,12 +4669,22 @@ static void ggml_compute_forward_scale_f32(
 
     const size_t nb1 = dst->nb[1];
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        if (dst->data != src0->data) {
-            // src0 is same shape as dst => same indices
-            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+    if (b == 0.0f) {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            if (dst->data != src0->data) {
+                // src0 is same shape as dst => same indices
+                // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
+                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+            }
+            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
+        }
+    } else {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            ggml_vec_mad1_f32(nc,
+                    (float *) ((char *) dst->data + i1*nb1),
+                    (float *) ((char *) src0->data + i1*nb1),
+                    s, b);
         }
-        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
     }
 }
 
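
Note: ggml_compute_forward_scale_f32 now reads two op params instead of one, and the bias path is fused into a single pass per row via ggml_vec_mad1_f32. A scalar sketch of the intended semantics (illustrative only; the real work happens in the vectorized helpers in vec.h):

```cpp
// reference semantics of the two branches above
static void scale_row_ref(float * y, const float * x, int n, float s, float b) {
    if (b == 0.0f) {
        for (int i = 0; i < n; ++i) y[i] = x[i]*s;     // plain scale (memcpy + scale path)
    } else {
        for (int i = 0; i < n; ++i) y[i] = x[i]*s + b; // fused multiply-add per element
    }
}
```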
@@ -5231,14 +5533,17 @@ static void ggml_compute_forward_soft_max_f32(
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    // TODO: handle transposed/permuted matrices
-
     const int ith = params->ith;
     const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-[… 1 removed line not preserved in this diff view …]
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     //   https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5248,68 +5553,66 @@ static void ggml_compute_forward_soft_max_f32(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
-    for ( …
-[… 21 more removed lines (the old flat row loop) not preserved in this diff view …]
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const int64_t i11 = i01;
+                const int64_t i12 = i02%ne12;
+                const int64_t i13 = i03%ne13;
+
+                // ALiBi
+                const uint32_t h = i02; // head
+                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);
+
+                // broadcast the mask across rows
+                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+                ggml_vec_cpy_f32  (ne00, wp, sp);
+                ggml_vec_scale_f32(ne00, wp, scale);
+                if (mp_f32) {
+                    if (use_f16) {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+                        }
+                    } else {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*mp_f32[i];
+                        }
+                    }
                 }
-        }
-    }
 
 #ifndef NDEBUG
-[… 4 removed lines not preserved in this diff view …]
+                for (int i = 0; i < ne00; ++i) {
+                    //printf("p[%d] = %f\n", i, p[i]);
+                    assert(!isnan(wp[i]));
+                }
 #endif
 
-[… 2 removed lines not preserved in this diff view …]
+                float max = -INFINITY;
+                ggml_vec_max_f32(ne00, &max, wp);
 
-[… 2 removed lines not preserved in this diff view …]
+                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+                assert(sum > 0.0);
 
-[… 2 removed lines not preserved in this diff view …]
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(ne00, dp, sum);
 
 #ifndef NDEBUG
-[… 4 removed lines not preserved in this diff view …]
+                for (int i = 0; i < ne00; ++i) {
+                    assert(!isnan(dp[i]));
+                    assert(!isinf(dp[i]));
+                }
 #endif
+            }
+        }
     }
 }
 
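
Note: the soft_max rework replaces the old flat row loop with explicit i03/i02/i01 loops so the mask can broadcast over dims 2 and 3 (the i02%ne12 and i03%ne13 indices), with the ALiBi slope picked per head from m0/m1. Per row the math is unchanged; a scalar sketch (illustrative only):

```cpp
#include <math.h>

// one row: y = softmax(scale*x + slope*mask)
static void soft_max_row_ref(float * y, const float * x, const float * mask,
                             int n, float scale, float slope) {
    float max = -INFINITY;
    for (int i = 0; i < n; ++i) {
        y[i] = scale*x[i] + (mask ? slope*mask[i] : 0.0f);
        max  = y[i] > max ? y[i] : max;
    }
    double sum = 0.0;
    for (int i = 0; i < n; ++i) {
        y[i] = expf(y[i] - max); // subtract the row max for numerical stability
        sum += y[i];
    }
    for (int i = 0; i < n; ++i) {
        y[i] /= (float) sum;
    }
}
```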
@@ -6545,6 +6848,186 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }
 
+static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                              void * a, void * b, float * c) {
+    const ggml_type_traits * traits = ggml_get_type_traits(type);
+    struct ggml_tensor src1 = {};
+    src1.type  = type;
+    src1.ne[0] = k;
+    src1.ne[1] = m;
+    src1.ne[2] = 1;
+    src1.ne[3] = 1;
+    src1.nb[0] = traits->type_size;
+    src1.nb[1] = k * traits->type_size;
+    src1.nb[2] = src1.nb[1];
+    src1.nb[3] = src1.nb[2];
+    src1.data  = a;
+
+    struct ggml_tensor src0 = {};
+    src0.type  = type;
+    src0.ne[0] = k;
+    src0.ne[1] = n;
+    src0.ne[2] = 1;
+    src0.ne[3] = 1;
+    src0.nb[0] = traits->type_size;
+    src0.nb[1] = k * traits->type_size;
+    src0.nb[2] = src0.nb[1];
+    src0.nb[3] = src0.nb[2];
+    src0.data  = b;
+
+    struct ggml_tensor dst = {};
+    dst.ne[0] = n;
+    dst.ne[1] = m;
+    dst.ne[2] = 1;
+    dst.ne[3] = 1;
+    dst.nb[0] = sizeof(float);
+    dst.nb[1] = n * sizeof(float);
+    dst.nb[2] = dst.nb[1];
+    dst.nb[3] = dst.nb[2];
+    dst.data  = c;
+    dst.src[0] = &src0;
+    dst.src[1] = &src1;
+
+    ggml_compute_forward_mul_mat(params, &dst);
+}
+
+// ggml_compute_forward_conv_2d
+
+static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
+                                              const ggml_tensor *         src,     // [W, H, C, N]
+                                              ggml_tensor *               dst,     // [OW, OH, OC, N]
+                                              ggml_type                   kernel_type) {
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+    const int32_t stride_x   = dst->op_params[0];
+    const int32_t stride_y   = dst->op_params[1];
+    const int32_t pad_x      = dst->op_params[2];
+    const int32_t pad_y      = dst->op_params[3];
+    const int32_t dilation_x = dst->op_params[4];
+    const int32_t dilation_y = dst->op_params[5];
+
+    const int64_t c_in  = src->ne[2];
+    const int64_t c_out = kernel->ne[3];
+    GGML_ASSERT(c_in == kernel->ne[2]);
+
+    const int64_t src_w = src->ne[0];
+    const int64_t src_h = src->ne[1];
+    const int64_t knl_w = kernel->ne[0];
+    const int64_t knl_h = kernel->ne[1];
+    const int64_t dst_w = dst->ne[0];
+    const int64_t dst_h = dst->ne[1];
+
+    const float * src_data = (float *) src->data;
+    void  *       knl_data = kernel->data;
+    float *       dst_data = (float *) dst->data;
+
+    const int64_t knl_n       = knl_w * knl_h * c_in;
+    const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
+
+    const int64_t space_per_patch   = knl_n * traits->type_size + c_out * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
+    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+    void * tmp = params->wdata;
+
+    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+
+        const int64_t patch_start_batch = batch_i * patches_per_batch;
+        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,
+                                                   patch_total);
+        const int64_t patch_n           = patch_end_batch - patch_start_batch;
+
+        const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
+        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+        //im2col for a patch
+        for (int64_t p = patch_start; p < patch_end; ++p) {
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t src_x   = (p / dst_w) % dst_h;
+            const int64_t src_y   = p % dst_w;
+
+            const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
+            char *        dst_row  = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
+
+            for (int64_t ic = 0; ic < c_in; ++ic) {
+                for (int64_t ky = 0; ky < knl_h; ++ky) {
+                    for (int64_t kx = 0; kx < knl_w; ++kx) {
+                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
+                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
+
+                        int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
+
+                        float src_val;
+                        if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+                            src_val = 0.0f;
+                        } else {
+                            const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
+                            src_val = *src_ptr;
+                        }
+
+                        char * element_ptr = dst_row + dst_idx * traits->type_size;
+                        if (kernel_type == GGML_TYPE_F32) {
+                            *(float *) element_ptr = src_val;
+                        } else if (kernel_type == GGML_TYPE_F16) {
+                            *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+                        }
+                    }
+                }
+            }
+        } // patches handled by this thread
+
+        ggml_barrier(params->threadpool);
+
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
+
+        GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);
+
+        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
+        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
+
+        ggml_barrier(params->threadpool);
+
+
+        //permute back [OC, N, OH, OW] to [N, OC, OH, OW]
+        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
+        const int64_t permute_start      = params->ith * permute_per_thread;
+        const int64_t permute_end        = std::min(permute_start + permute_per_thread, patch_n);
+
+        for (int64_t i = permute_start; i < permute_end; ++i) {
+            const int64_t p       = patch_start_batch + i;
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t dst_y   = (p / dst_w) % dst_h;
+            const int64_t dst_x   = p % dst_w;
+
+            for (int64_t oc = 0; oc < c_out; ++oc) {
+                const float value = gemm_output[i * c_out + oc];
+                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
+                *dst_ptr = value;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_2d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(
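
Note: the new conv_2d path is im2col into the shared wdata scratch buffer, one GEMM per patch batch (the `// GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out]` step above, routed through ggml_compute_forward_mul_mat via the stack ggml_tensor wrappers in ggml_call_mul_mat), then a permute back into dst. The dst extents are assumed to follow the usual convolution arithmetic — a sketch of that convention, not code from the diff:

```cpp
#include <cstdint>

// output size along one spatial axis (standard stride/pad/dilation convention)
static int64_t conv2d_out_size(int64_t in, int64_t kernel,
                               int32_t stride, int32_t pad, int32_t dilation) {
    return (in + 2*pad - dilation*(kernel - 1) - 1)/stride + 1;
}
// e.g. in=5, kernel=3, stride=1, pad=1, dilation=1 -> (5 + 2 - 2 - 1)/1 + 1 = 5
```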
@@ -7095,12 +7578,13 @@ static void ggml_compute_forward_upscale_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    const float sf0 = (float)ne0/src0->ne[0];
-    const float sf1 = (float)ne1/src0->ne[1];
-    const float sf2 = (float)ne2/src0->ne[2];
-    const float sf3 = (float)ne3/src0->ne[3];
+    float sf0 = (float)ne0/src0->ne[0];
+    float sf1 = (float)ne1/src0->ne[1];
+    float sf2 = (float)ne2/src0->ne[2];
+    float sf3 = (float)ne3/src0->ne[3];
 
-    const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
+    const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
+    const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
 
     if (mode == GGML_SCALE_MODE_NEAREST) {
         for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -7121,8 +7605,12 @@ static void ggml_compute_forward_upscale_f32(
             }
         }
     } else if (mode == GGML_SCALE_MODE_BILINEAR) {
-[… 2 removed lines not preserved in this diff view …]
+        float pixel_offset = 0.5f;
+        if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+            pixel_offset = 0.0f;
+            sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
+            sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
+        }
 
         for (int64_t i3 = 0; i3 < ne3; i3++) {
             const int64_t i03 = i3 / sf3;
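
Note: GGML_SCALE_FLAG_ALIGN_CORNERS arrives packed in the upper bits of op param 0, while the low byte (`mode_flags & 0xFF`) remains the scale mode. With the flag set, the grid endpoints of source and destination coincide. A worked example, assuming the bilinear path samples the source at (i + pixel_offset)/sf - pixel_offset (the usual half-pixel convention; not shown in this hunk):

```cpp
// widening a row of 2 pixels to 4, bilinear:
//   default:       sf0 = 4/2 = 2.0, pixel_offset = 0.5
//                  dst x=0 samples src (0 + 0.5)/2.0 - 0.5 = -0.25 (clamped to the border)
//   align corners: sf0 = (4-1)/(2-1) = 3.0, pixel_offset = 0.0
//                  dst x=0 -> src 0.0 and dst x=3 -> src 1.0, endpoints map exactly
```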
@@ -7580,7 +8068,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_type …
+    ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
     ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
     ggml_vec_dot_t    const kq_vec_dot   = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t   const v_to_float   = ggml_get_type_traits(v->type)->to_float;
@@ -7612,7 +8100,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
@@ -8150,120 +8638,210 @@ void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // s
-    const ggml_tensor * src1 = dst->src[1]; // x
-    const ggml_tensor * src2 = dst->src[2]; // dt
-    const ggml_tensor * src3 = dst->src[3]; // A
-    const ggml_tensor * src4 = dst->src[4]; // B
-    const ggml_tensor * src5 = dst->src[5]; // C
+    const ggml_tensor * src0 = dst->src[0]; // s   {d_state, dim, n_head, n_seqs+}
+    const ggml_tensor * src1 = dst->src[1]; // x   {dim, n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src2 = dst->src[2]; // dt  {n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src3 = dst->src[3]; // A   {d_state, n_head} or {1, n_head}
+    const ggml_tensor * src4 = dst->src[4]; // B   {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src5 = dst->src[5]; // C   {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t nc …
-    const int64_t nr …
-    const int64_t …
-    const int64_t …
+    const int64_t nc = src0->ne[0]; // d_state
+    const int64_t nr = src0->ne[1]; // dim
+    const int64_t nh = src1->ne[1]; // n_head
+    const int64_t ng = src4->ne[1];
+    const int64_t nt = src1->ne[2]; // number of tokens per sequence
+    const int64_t ns = src1->ne[3]; // number of sequences in the batch
+
+    // can't use ggml_nbytes because src1 is not necessarily contiguous
+    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
 
-    GGML_ASSERT(ggml_nelements(src1) + …
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-[… 3 removed lines not preserved in this diff view …]
-    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+    // allows optimizing the modulo since n_group should be a power of 2
+    GGML_ASSERT((ng & -ng) == ng);
 
-    // …
-    const int …
+    // heads per thread
+    const int dh = (nh + nth - 1)/nth;
 
-    // …
-    const int …
-    const int …
-[… 1 removed line not preserved in this diff view …]
+    // head range for this thread
+    const int ih0 = dh*ith;
+    const int ih1 = MIN(ih0 + dh, nh);
+
+    const int32_t * ids = (const int32_t *) src6->data;
+
+    for (int i3 = 0; i3 < ns; ++i3) {
+        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+        float *       s  = (      float *) ((      char *)  dst->data +  i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+        for (int i2 = 0; i2 < nt; ++i2) {
+            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+            float *       y  = (      float *) ((      char *)  dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+            if (src3->ne[0] == 1) {
+                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                    const float dA = expf(dt_soft_plus * A[h]);
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+                        float sumf = 0.0f;
+#if defined(GGML_SIMD)
+    #if defined(__ARM_FEATURE_SVE)
+                        const int ggml_f32_epr  = svcntw();
+                        const int ggml_f32_step = 1 * ggml_f32_epr;
+
+                        const int np = (nc & ~(ggml_f32_step - 1));
+
+                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+                        GGML_F32_VEC adA  = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        for (int i = 0; i < np; i += ggml_f32_step) {
+                            // TODO: maybe unroll more?
+                            for (int j = 0; j < 1; j++) {
+                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B  + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C  + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+
+                                t0 = GGML_F32_VEC_MUL(t0, adA);
+                                t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+                                t0 = GGML_F32_VEC_ADD(t0, t1);
+
+                                sum = GGML_F32_VEC_FMA(sum, t0, t2);
+
+                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+                            }
+                        }
+
+                        sumf = GGML_F32xt_REDUCE_ONE(sum);
+    #else
+                        const int np = (nc & ~(GGML_F32_STEP - 1));
+
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+                        GGML_F32_VEC adA  = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        GGML_F32_VEC az[GGML_F32_ARR];
 
-[… 34 removed lines not preserved in this diff view …]
-                r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-[… 1 removed line not preserved in this diff view …]
-                GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+                        for (int i = 0; i < np; i += GGML_F32_STEP) {
+                            for (int j = 0; j < GGML_F32_ARR; j++) {
+                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+                                ay[j] = GGML_F32_VEC_LOAD(B  + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+                                az[j] = GGML_F32_VEC_LOAD(C  + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+
+                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
+
+                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+                            }
+                        }
+
+                        // reduce sum0..sum3 to sum0
+                        GGML_F32_VEC_REDUCE(sumf, sum);
+    #endif
+#else
+                        const int np = 0;
+#endif
+                        // d_state
+                        for (int i0 = np; i0 < nc; ++i0) {
+                            const int i  = i0 + ii*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
                     }
-                y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
                 }
-            }
-[… 30 removed lines not preserved in this diff view …]
+            } else {
+                // Mamba-1 has an element-wise decay factor for the states
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+#if defined(__ARM_FEATURE_SVE)
+                        svfloat32_t vx_dt         = GGML_F32_VEC_SET1(x_dt);
+                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                        svfloat32_t r1_vector     = GGML_F32_VEC_ZERO;
+
+                        // d_state
+                        // TODO: what happens when (d_state % svcntw()) != 0?
+                        for (int64_t k = 0; k < nc; k += svcntw()) {
+                            svfloat32_t vA  = GGML_F32_VEC_LOAD(&A[h*nc + k]);
+                            svfloat32_t vB  = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
+                            svfloat32_t vC  = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+                            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+                            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                            t1 = exp_ps_sve(svptrue_b32(), t1);
+                            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                            vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
+                            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                            GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+                        }
+                        y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
+#else
+                        float sumf = 0.0f;
+                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+                        //       and also because expf is used within the loop.
+                        // d_state
+                        for (int i0 = 0; i0 < nc; ++i0) {
+                            const int i  = i0 + ii*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
+#endif
                     }
-                y[i1] = sumf;
                 }
             }
+            // use the output as the source when it's not the first token-wise iteration
+            s0 = s;
         }
-[… 1 removed line not preserved in this diff view …]
+    }
 }
 
 void ggml_compute_forward_ssm_scan(
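
Note: both branches of the rewritten scan implement the same per-head selective-state update; they differ only in whether the decay is a scalar per head (Mamba-2, src3->ne[0] == 1, so dA hoists out of the state loop) or element-wise (Mamba-1). A scalar reference of one (head, channel) step that the SIMD paths above vectorize (illustrative sketch):

```cpp
#include <math.h>

// state[j] = state[j]*exp(dt'*A[j]) + B[j]*(x*dt');  y = sum_j state[j]*C[j]
static float ssm_scan_step_ref(float * state, const float * A, const float * B,
                               const float * C, int d_state, float x, float dt) {
    const float dt_sp = dt <= 20.0f ? log1pf(expf(dt)) : dt; // softplus, overflow-guarded
    const float x_dt  = x * dt_sp;
    float y = 0.0f;
    for (int j = 0; j < d_state; ++j) {
        state[j] = state[j]*expf(dt_sp*A[j]) + B[j]*x_dt;
        y += state[j]*C[j];
    }
    return y;
}
```

The new src6 `ids` tensor lets each sequence pick which stored state slot to read (s0), while the updated state is written past s_off in the output buffer and then reused as the source for the next token (s0 = s).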
@@ -8502,6 +9080,14 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
            } break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            {
+                ggml_compute_forward_geglu_erf(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            {
+                ggml_compute_forward_geglu_quick(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");